|
|
|
@ -21,6 +21,7 @@
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <queue>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -75,7 +76,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
|
|
|
|
|
return default_target;
|
|
|
|
|
}
|
|
|
|
|
auto primitive = value->cast<PrimitivePtr>();
|
|
|
|
|
ValuePtr att_target = primitive->GetAttr("target");
|
|
|
|
|
ValuePtr att_target = primitive->GetAttr("primitive_target");
|
|
|
|
|
if (att_target != nullptr) {
|
|
|
|
|
std::string target = GetValue<std::string>(att_target);
|
|
|
|
|
return target;
|
|
|
|
@ -127,6 +128,58 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
|
|
|
|
|
std::vector<AnfNodePtr> result;
|
|
|
|
|
std::stack<AnfNodePtr> to_visit;
|
|
|
|
|
std::stack<AnfNodePtr> next_to_visit;
|
|
|
|
|
std::map<AnfNodePtr, size_t> nodes_ref;
|
|
|
|
|
CalcNodeRefCount(graph, &nodes_ref);
|
|
|
|
|
std::string handle_target = default_target;
|
|
|
|
|
std::string next_target = "";
|
|
|
|
|
to_visit.push(graph->get_return());
|
|
|
|
|
while (!to_visit.empty() || !next_to_visit.empty()) {
|
|
|
|
|
if (to_visit.empty()) {
|
|
|
|
|
to_visit.swap(next_to_visit);
|
|
|
|
|
handle_target = next_target;
|
|
|
|
|
}
|
|
|
|
|
auto &node = to_visit.top();
|
|
|
|
|
to_visit.pop();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
result.emplace_back(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto node_inputs = cnode->inputs();
|
|
|
|
|
std::reverse(node_inputs.begin(), node_inputs.end());
|
|
|
|
|
for (auto &input : node_inputs) {
|
|
|
|
|
auto iter = nodes_ref.find(input);
|
|
|
|
|
if (iter != nodes_ref.end()) {
|
|
|
|
|
iter->second--;
|
|
|
|
|
if (iter->second != 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!input->isa<CNode>()) {
|
|
|
|
|
to_visit.push(input);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::string input_target = GetCNodeTarget(input);
|
|
|
|
|
if (input_target == handle_target) {
|
|
|
|
|
to_visit.push(input);
|
|
|
|
|
} else if (next_to_visit.empty() || input_target == next_target) {
|
|
|
|
|
next_to_visit.push(input);
|
|
|
|
|
next_target = input_target;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only support two different target";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::reverse(result.begin(), result.end());
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
|
|
|
|
@ -180,65 +233,16 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) {
|
|
|
|
|
std::vector<AnfNodePtr> result;
|
|
|
|
|
std::queue<AnfNodePtr> queue;
|
|
|
|
|
std::queue<AnfNodePtr> next_queue;
|
|
|
|
|
std::map<AnfNodePtr, size_t> nodes_ref;
|
|
|
|
|
CalcNodeRefCount(graph, &nodes_ref);
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
std::string queue_target = context_ptr->device_target();
|
|
|
|
|
std::string next_target = "";
|
|
|
|
|
queue.push(graph->get_return());
|
|
|
|
|
while (!queue.empty() || !next_queue.empty()) {
|
|
|
|
|
if (queue.empty()) {
|
|
|
|
|
queue.swap(next_queue);
|
|
|
|
|
queue_target = next_target;
|
|
|
|
|
}
|
|
|
|
|
auto &node = queue.front();
|
|
|
|
|
queue.pop();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
result.emplace_back(node);
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
for (auto &input : cnode->inputs()) {
|
|
|
|
|
auto iter = nodes_ref.find(input);
|
|
|
|
|
if (iter != nodes_ref.end()) {
|
|
|
|
|
iter->second--;
|
|
|
|
|
if (iter->second != 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!input->isa<CNode>()) {
|
|
|
|
|
queue.push(input);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::string input_target = GetCNodeTarget(input);
|
|
|
|
|
if (input_target == queue_target) {
|
|
|
|
|
queue.push(input);
|
|
|
|
|
} else if (next_queue.empty() || input_target == next_target) {
|
|
|
|
|
next_queue.push(input);
|
|
|
|
|
next_target = input_target;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "only support two different target";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::reverse(result.begin(), result.end());
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
VectorRef splits;
|
|
|
|
|
VectorRef split;
|
|
|
|
|
auto nodes = TopoSort(graph->get_return());
|
|
|
|
|
if (ContainMultiTarget(nodes)) {
|
|
|
|
|
nodes = SplitSort(graph);
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
std::string default_target = context_ptr->device_target();
|
|
|
|
|
nodes = SplitSort(graph, default_target);
|
|
|
|
|
}
|
|
|
|
|
std::string last_target;
|
|
|
|
|
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
|
|
|
|