|
|
|
@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsGetItemNode(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
if (inputs.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
|
|
|
|
}
|
|
|
|
|
if (!IsValueNode<Primitive>(inputs[0])) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
|
|
|
|
|
return node_prim->name() == prim::kPrimTupleGetItem->name();
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
|
|
|
|
|
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
|
|
|
|
|
std::vector<AnfNodePtr> result;
|
|
|
|
|
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
|
|
|
|
|
std::map<AnfNodePtr, size_t> node_positions;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (IsGetItemNode(node)) {
|
|
|
|
|
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::reverse(result.begin(), result.end());
|
|
|
|
|
return ReorderGetItemNode(result);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|
|
|
|
VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto nodes = OptimizeGetItemOrder(input_nodes);
|
|
|
|
|
VectorRef splits;
|
|
|
|
|
VectorRef split;
|
|
|
|
|
auto nodes = TopoSort(graph->get_return());
|
|
|
|
|
if (ContainMultiTarget(nodes)) {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
std::string default_target = context_ptr->device_target();
|
|
|
|
|
nodes = SplitSort(graph, default_target);
|
|
|
|
|
}
|
|
|
|
|
std::string last_target;
|
|
|
|
|
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (IsCut(node)) {
|
|
|
|
@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|
|
|
|
return splits;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto nodes = TopoSort(graph->get_return());
|
|
|
|
|
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
|
|
|
|
|
|
|
|
|
if (ContainMultiTarget(nodes)) {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
std::string default_target = context_ptr->device_target();
|
|
|
|
|
nodes = SplitSort(graph, default_target);
|
|
|
|
|
return SplitNodesWithTarget(nodes, graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VectorRef splits;
|
|
|
|
|
VectorRef split;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (IsCut(node)) {
|
|
|
|
|
if (split.size() != 0) {
|
|
|
|
|
splits.push_back(split);
|
|
|
|
|
}
|
|
|
|
|
splits.push_back(node);
|
|
|
|
|
split.clear();
|
|
|
|
|
} else if (node->isa<CNode>()) {
|
|
|
|
|
split.push_back(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return splits;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Push the value node on the stack.
|
|
|
|
|
void CompileGraph::Push(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|