diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index c9252bf08c..dbe2983cac 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -38,6 +38,16 @@ namespace parallel { static std::unordered_map> parameter_color_map; static int send_tag = 0; static int recv_tag = 0; +const std::set WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem}; + +static bool IsInWhiteList(const CNodePtr &cnode) { + for (auto &prim : WHITE_LIST) { + if (IsPrimitiveCNode(cnode, prim)) { + return true; + } + } + return false; +} void PipelineTransformer::Coloring() { auto need_coloring = true; @@ -85,7 +95,7 @@ void PipelineTransformer::BroadCastColoring() { bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); auto prim = GetValueNode(cnode->input(0)); - if (prim == nullptr) { + if (IsInWhiteList(cnode)) { return false; } if (IsInBlackList(prim)) { @@ -138,42 +148,21 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { std::pair PipelineTransformer::GetOpInfo(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - // handle send/recv a parameter - if (node->isa()) { - MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv."; - return std::make_pair(nullptr, nullptr); - } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr op_info = nullptr; - TensorInfo tensor_info; - // op1(stage1)->op2(stage2) - if (IsValueNode(cnode->input(0))) { - op_info = CreateOpInfo(cnode); - MS_EXCEPTION_IF_NULL(op_info); - tensor_info = op_info->outputs_tensor_info()[0]; - } else if (IsValueNode(cnode->input(0))) { - auto graph = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(graph); - auto output = graph->output(); - MS_EXCEPTION_IF_NULL(output); - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto prim = GetValueNode(output_cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == TUPLE_GETITEM) { - auto index = GetTupleGetItemIndex(output_cnode); - auto pre_getitem_node = output_cnode->input(1)->cast(); - MS_EXCEPTION_IF_NULL(pre_getitem_node); - op_info = CreateOpInfo(pre_getitem_node); - MS_EXCEPTION_IF_NULL(op_info); - tensor_info = op_info->outputs_tensor_info()[index]; - } else { - op_info = CreateOpInfo(output_cnode); - MS_EXCEPTION_IF_NULL(op_info); - tensor_info = op_info->outputs_tensor_info()[0]; - } - } + // Handle Cast and TupleGetitem situation + size_t tensor_info_index = 0; + if (IsPrimitiveCNode(cnode, prim::kPrimCast)) { + cnode = cnode->input(1)->cast(); + } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { + tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode)); + cnode = cnode->input(1)->cast(); + } + // Create OperatorInfo to get slice_shape for send/recv + MS_EXCEPTION_IF_NULL(cnode); + auto op_info = CreateOpInfo(cnode); + MS_EXCEPTION_IF_NULL(op_info); + auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index]; return std::make_pair(op_info, std::make_shared(tensor_info)); } @@ -316,6 +305,29 @@ static std::pair GetShapeType(const AnfNodePtr &node, con return std::make_pair(shape_list, dtype); } +AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsValueNode(cnode->input(0))) { + auto graph = GetValueNode(cnode->input(0)); + auto output = graph->output(); + MS_EXCEPTION_IF_NULL(output); + if (output->isa()) { + return output; + } + cnode = output->cast(); + MS_EXCEPTION_IF_NULL(cnode); + } + if (IsInWhiteList(cnode)) { + return cnode->cast(); + } + if (!IsPipelineCareNode(cnode)) { + MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."; + } + return cnode->cast(); +} + SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int user_node_stage, int node_stage) { Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); @@ -330,7 +342,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod if (parameter->isa()) { op_info_pair = GetParameterPair(parameter); } else { - op_info_pair = GetOpInfo(parameter); + auto care_node = FindPipelineCareNode(parameter); + if (care_node->isa()) { + op_info_pair = GetParameterPair(care_node); + } else { + op_info_pair = GetOpInfo(care_node); + } } auto tensor_info = op_info_pair.second; MS_EXCEPTION_IF_NULL(tensor_info); @@ -360,7 +377,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode if (node->isa()) { op_info_pair = GetParameterPair(node); } else { - op_info_pair = GetOpInfo(node); + auto care_node = FindPipelineCareNode(node); + if (care_node->isa()) { + op_info_pair = GetParameterPair(care_node); + } else { + op_info_pair = GetOpInfo(care_node); + } } auto tensor_info = op_info_pair.second; MS_EXCEPTION_IF_NULL(tensor_info); diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h index 10290128d3..fcd40fa43e 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h @@ -64,6 +64,7 @@ class PipelineTransformer { int user_node_stage, int node_stage); void CutBorder(const FuncGraphPtr &graph); bool IsStageNode(const CNodePtr &node); + AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); std::pair GetOpInfo(const AnfNodePtr &node); std::pair GetParameterPair(const AnfNodePtr &node); OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);