|
|
|
@ -38,6 +38,16 @@ namespace parallel {
|
|
|
|
|
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
|
|
|
|
|
static int send_tag = 0;
|
|
|
|
|
static int recv_tag = 0;
|
|
|
|
|
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem};
|
|
|
|
|
|
|
|
|
|
static bool IsInWhiteList(const CNodePtr &cnode) {
|
|
|
|
|
for (auto &prim : WHITE_LIST) {
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim)) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PipelineTransformer::Coloring() {
|
|
|
|
|
auto need_coloring = true;
|
|
|
|
@ -85,7 +95,7 @@ void PipelineTransformer::BroadCastColoring() {
|
|
|
|
|
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
if (IsInWhiteList(cnode)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (IsInBlackList(prim)) {
|
|
|
|
@ -138,42 +148,21 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
|
|
|
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
// handle send/recv a parameter
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv.";
|
|
|
|
|
return std::make_pair(nullptr, nullptr);
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
OperatorInfoPtr op_info = nullptr;
|
|
|
|
|
TensorInfo tensor_info;
|
|
|
|
|
// op1(stage1)->op2(stage2)
|
|
|
|
|
if (IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
op_info = CreateOpInfo(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
tensor_info = op_info->outputs_tensor_info()[0];
|
|
|
|
|
} else if (IsValueNode<FuncGraph>(cnode->input(0))) {
|
|
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto output = graph->output();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
auto output_cnode = output->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode);
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(output_cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim->name() == TUPLE_GETITEM) {
|
|
|
|
|
auto index = GetTupleGetItemIndex(output_cnode);
|
|
|
|
|
auto pre_getitem_node = output_cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_getitem_node);
|
|
|
|
|
op_info = CreateOpInfo(pre_getitem_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
tensor_info = op_info->outputs_tensor_info()[index];
|
|
|
|
|
} else {
|
|
|
|
|
op_info = CreateOpInfo(output_cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
tensor_info = op_info->outputs_tensor_info()[0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Handle Cast and TupleGetitem situation
|
|
|
|
|
size_t tensor_info_index = 0;
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
|
|
|
|
|
cnode = cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
|
|
|
|
tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
|
|
|
|
|
cnode = cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
}
|
|
|
|
|
// Create OperatorInfo to get slice_shape for send/recv
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto op_info = CreateOpInfo(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
|
|
|
|
|
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -316,6 +305,29 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
|
|
|
|
|
return std::make_pair(shape_list, dtype);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (IsValueNode<FuncGraph>(cnode->input(0))) {
|
|
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
|
|
|
|
auto output = graph->output();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (output->isa<Parameter>()) {
|
|
|
|
|
return output;
|
|
|
|
|
}
|
|
|
|
|
cnode = output->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
}
|
|
|
|
|
if (IsInWhiteList(cnode)) {
|
|
|
|
|
return cnode->cast<AnfNodePtr>();
|
|
|
|
|
}
|
|
|
|
|
if (!IsPipelineCareNode(cnode)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border.";
|
|
|
|
|
}
|
|
|
|
|
return cnode->cast<AnfNodePtr>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int user_node_stage,
|
|
|
|
|
int node_stage) {
|
|
|
|
|
Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag));
|
|
|
|
@ -330,7 +342,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
|
|
|
|
|
if (parameter->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(parameter);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(parameter);
|
|
|
|
|
auto care_node = FindPipelineCareNode(parameter);
|
|
|
|
|
if (care_node->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(care_node);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(care_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto tensor_info = op_info_pair.second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
|
|
|
@ -360,7 +377,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(node);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(node);
|
|
|
|
|
auto care_node = FindPipelineCareNode(node);
|
|
|
|
|
if (care_node->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(care_node);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(care_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto tensor_info = op_info_pair.second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
|
|
|
|