optimize splitsort

pull/1817/head
kswang 5 years ago
parent 5c4731b772
commit 472d87fee1

@ -65,8 +65,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
GraphId graph_id = kInvalidGraphId;
if (target == kCPUDevice) {
graph_id = cpu_sess_->CompileGraph(lst, outputs);
if (target != target_device_ && target != "") {
CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(lst, outputs);
} else {
graph_id = target_sess_->CompileGraph(lst, outputs);
}
@ -75,8 +76,8 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
}
if (target == kCPUDevice) {
cpu_sess_->BuildGraph(graph_id);
if (target != target_device_ && target != "") {
other_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
}
@ -278,8 +279,8 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
if (target == kCPUDevice) {
cpu_sess_->RunGraph(g, inputs, &outputs);
if (target != target_device_ && target != "") {
other_sess_->RunGraph(g, inputs, &outputs);
} else {
target_sess_->RunGraph(g, inputs, &outputs);
}
@ -341,16 +342,20 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
}
target_sess_->Init(device_id);
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
if (target == kCPUDevice) {
cpu_sess_ = target_sess_;
} else {
cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
if (cpu_sess_ == nullptr) {
MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << ".";
}
cpu_sess_->Init(0);
cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
target_device_ = target;
}
void MsBackend::CreateOtherSession(const std::string &target) {
if (other_sess_ != nullptr && other_device_ == target) {
return;
}
other_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
if (other_sess_ == nullptr) {
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
}
other_sess_->Init(0);
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
other_device_ = target;
}
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }

@ -107,10 +107,13 @@ class MsBackend : public Backend {
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
void CreateOtherSession(const std::string &target);
private:
session::SessionPtr target_sess_;
session::SessionPtr cpu_sess_;
session::SessionPtr other_sess_;
std::string target_device_;
std::string other_device_;
std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;

@ -21,6 +21,7 @@
#include <algorithm>
#include <map>
#include <queue>
#include <stack>
#include <set>
#include <string>
#include <vector>
@ -75,7 +76,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
ValuePtr att_target = primitive->GetAttr("target");
ValuePtr att_target = primitive->GetAttr("primitive_target");
if (att_target != nullptr) {
std::string target = GetValue<std::string>(att_target);
return target;
@ -127,6 +128,58 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}
}
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> to_visit;
std::stack<AnfNodePtr> next_to_visit;
std::map<AnfNodePtr, size_t> nodes_ref;
CalcNodeRefCount(graph, &nodes_ref);
std::string handle_target = default_target;
std::string next_target = "";
to_visit.push(graph->get_return());
while (!to_visit.empty() || !next_to_visit.empty()) {
if (to_visit.empty()) {
to_visit.swap(next_to_visit);
handle_target = next_target;
}
auto &node = to_visit.top();
to_visit.pop();
MS_EXCEPTION_IF_NULL(node);
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
std::reverse(node_inputs.begin(), node_inputs.end());
for (auto &input : node_inputs) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
to_visit.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == handle_target) {
to_visit.push(input);
} else if (next_to_visit.empty() || input_target == next_target) {
next_to_visit.push(input);
next_target = input_target;
} else {
MS_LOG(EXCEPTION) << "only support two different target";
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
} // namespace
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
@ -180,65 +233,16 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
return false;
}
std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> result;
std::queue<AnfNodePtr> queue;
std::queue<AnfNodePtr> next_queue;
std::map<AnfNodePtr, size_t> nodes_ref;
CalcNodeRefCount(graph, &nodes_ref);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string queue_target = context_ptr->device_target();
std::string next_target = "";
queue.push(graph->get_return());
while (!queue.empty() || !next_queue.empty()) {
if (queue.empty()) {
queue.swap(next_queue);
queue_target = next_target;
}
auto &node = queue.front();
queue.pop();
MS_EXCEPTION_IF_NULL(node);
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
queue.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == queue_target) {
queue.push(input);
} else if (next_queue.empty() || input_target == next_target) {
next_queue.push(input);
next_target = input_target;
} else {
MS_LOG(EXCEPTION) << "only support two different target";
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
VectorRef splits;
VectorRef split;
auto nodes = TopoSort(graph->get_return());
if (ContainMultiTarget(nodes)) {
nodes = SplitSort(graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
nodes = SplitSort(graph, default_target);
}
std::string last_target;
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();

@ -79,7 +79,6 @@ class CompileGraph {
private:
void PushParameters(const FuncGraphPtr &func_graph);
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph);
bool SplitGraph(const FuncGraphPtr &func_graph);
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);

Loading…
Cancel
Save