pipeline_split_adapt_gpt

pull/9614/head
lichenever 4 years ago
parent 5a0ae06bb0
commit 8ecc188119

@ -38,6 +38,16 @@ namespace parallel {
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map; static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map;
static int send_tag = 0; static int send_tag = 0;
static int recv_tag = 0; static int recv_tag = 0;
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem};
static bool IsInWhiteList(const CNodePtr &cnode) {
for (auto &prim : WHITE_LIST) {
if (IsPrimitiveCNode(cnode, prim)) {
return true;
}
}
return false;
}
void PipelineTransformer::Coloring() { void PipelineTransformer::Coloring() {
auto need_coloring = true; auto need_coloring = true;
@ -85,7 +95,7 @@ void PipelineTransformer::BroadCastColoring() {
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) { if (IsInWhiteList(cnode)) {
return false; return false;
} }
if (IsInBlackList(prim)) { if (IsInBlackList(prim)) {
@ -138,42 +148,21 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) { std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
// handle send/recv a parameter
if (node->isa<Parameter>()) {
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv.";
return std::make_pair(nullptr, nullptr);
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr op_info = nullptr; // Handle Cast and TupleGetitem situation
TensorInfo tensor_info; size_t tensor_info_index = 0;
// op1(stage1)->op2(stage2) if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
if (IsValueNode<Primitive>(cnode->input(0))) { cnode = cnode->input(1)->cast<CNodePtr>();
op_info = CreateOpInfo(cnode); } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
MS_EXCEPTION_IF_NULL(op_info); tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
tensor_info = op_info->outputs_tensor_info()[0]; cnode = cnode->input(1)->cast<CNodePtr>();
} else if (IsValueNode<FuncGraph>(cnode->input(0))) { }
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); // Create OperatorInfo to get slice_shape for send/recv
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(cnode);
auto output = graph->output(); auto op_info = CreateOpInfo(cnode);
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(op_info);
auto output_cnode = output->cast<CNodePtr>(); auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
MS_EXCEPTION_IF_NULL(output_cnode);
auto prim = GetValueNode<PrimitivePtr>(output_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == TUPLE_GETITEM) {
auto index = GetTupleGetItemIndex(output_cnode);
auto pre_getitem_node = output_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_getitem_node);
op_info = CreateOpInfo(pre_getitem_node);
MS_EXCEPTION_IF_NULL(op_info);
tensor_info = op_info->outputs_tensor_info()[index];
} else {
op_info = CreateOpInfo(output_cnode);
MS_EXCEPTION_IF_NULL(op_info);
tensor_info = op_info->outputs_tensor_info()[0];
}
}
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info)); return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
} }
@ -316,6 +305,29 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
return std::make_pair(shape_list, dtype); return std::make_pair(shape_list, dtype);
} }
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsValueNode<FuncGraph>(cnode->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto output = graph->output();
MS_EXCEPTION_IF_NULL(output);
if (output->isa<Parameter>()) {
return output;
}
cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
}
if (IsInWhiteList(cnode)) {
return cnode->cast<AnfNodePtr>();
}
if (!IsPipelineCareNode(cnode)) {
MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border.";
}
return cnode->cast<AnfNodePtr>();
}
SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage,
int node_stage) { int node_stage) {
Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag));
@ -330,7 +342,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
if (parameter->isa<Parameter>()) { if (parameter->isa<Parameter>()) {
op_info_pair = GetParameterPair(parameter); op_info_pair = GetParameterPair(parameter);
} else { } else {
op_info_pair = GetOpInfo(parameter); auto care_node = FindPipelineCareNode(parameter);
if (care_node->isa<Parameter>()) {
op_info_pair = GetParameterPair(care_node);
} else {
op_info_pair = GetOpInfo(care_node);
}
} }
auto tensor_info = op_info_pair.second; auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info); MS_EXCEPTION_IF_NULL(tensor_info);
@ -360,7 +377,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
op_info_pair = GetParameterPair(node); op_info_pair = GetParameterPair(node);
} else { } else {
op_info_pair = GetOpInfo(node); auto care_node = FindPipelineCareNode(node);
if (care_node->isa<Parameter>()) {
op_info_pair = GetParameterPair(care_node);
} else {
op_info_pair = GetOpInfo(care_node);
}
} }
auto tensor_info = op_info_pair.second; auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info); MS_EXCEPTION_IF_NULL(tensor_info);

@ -64,6 +64,7 @@ class PipelineTransformer {
int user_node_stage, int node_stage); int user_node_stage, int node_stage);
void CutBorder(const FuncGraphPtr &graph); void CutBorder(const FuncGraphPtr &graph);
bool IsStageNode(const CNodePtr &node); bool IsStageNode(const CNodePtr &node);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode); OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);

Loading…
Cancel
Save