|
|
@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
|
|
|
|
multi_result_.inputs = g->parameters();
|
|
|
|
multi_result_.inputs = g->parameters();
|
|
|
|
final_output_ = NewValueNode("fake_output");
|
|
|
|
final_output_ = NewValueNode("fake_output");
|
|
|
|
multi_result_.outputs = {final_output_};
|
|
|
|
multi_result_.outputs = {final_output_};
|
|
|
|
GraphId final_g = sess_->GetFinalRunGraph();
|
|
|
|
GraphId final_g = target_sess_->GetFinalRunGraph();
|
|
|
|
|
|
|
|
|
|
|
|
multi_result_.run = std::make_shared<RunFunc>(
|
|
|
|
multi_result_.run = std::make_shared<RunFunc>(
|
|
|
|
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); });
|
|
|
|
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
|
|
|
|
return multi_result_;
|
|
|
|
return multi_result_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
|
|
|
|
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
|
|
|
|
MS_LOG(DEBUG) << "MsConvert";
|
|
|
|
MS_LOG(DEBUG) << "MsConvert";
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
auto cached = g_ConvertCache.find(lst);
|
|
|
|
auto cached = g_ConvertCache.find(lst);
|
|
|
@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
|
|
|
|
result.inputs = inputs;
|
|
|
|
result.inputs = inputs;
|
|
|
|
result.outputs = outputs;
|
|
|
|
result.outputs = outputs;
|
|
|
|
result.graph_id = kInvalidGraphId;
|
|
|
|
result.graph_id = kInvalidGraphId;
|
|
|
|
auto graph_id = sess_->CompileGraph(lst, outputs);
|
|
|
|
GraphId graph_id = kInvalidGraphId;
|
|
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
|
|
|
if (target == kCPUDevice) {
|
|
|
|
sess_->BuildGraph(graph_id);
|
|
|
|
graph_id = cpu_sess_->CompileGraph(lst, outputs);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
graph_id = target_sess_->CompileGraph(lst, outputs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (MsContext::GetInstance()->precompile_only()) {
|
|
|
|
if (MsContext::GetInstance()->precompile_only()) {
|
|
|
|
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
|
|
|
|
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
|
|
|
|
return result;
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (target == kCPUDevice) {
|
|
|
|
|
|
|
|
cpu_sess_->BuildGraph(graph_id);
|
|
|
|
|
|
|
|
} else if (!is_multi_graph_sink_) {
|
|
|
|
|
|
|
|
target_sess_->BuildGraph(graph_id);
|
|
|
|
|
|
|
|
}
|
|
|
|
result.run = std::make_shared<RunFunc>(
|
|
|
|
result.run = std::make_shared<RunFunc>(
|
|
|
|
[graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); });
|
|
|
|
[graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
|
|
|
|
MS_EXCEPTION_IF_NULL(result.run);
|
|
|
|
MS_EXCEPTION_IF_NULL(result.run);
|
|
|
|
|
|
|
|
|
|
|
|
result.simu_run = std::make_shared<RunFunc>(
|
|
|
|
result.simu_run = std::make_shared<RunFunc>(
|
|
|
@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
|
|
|
|
|
|
|
|
|
|
|
|
GraphId cond_g = kInvalidGraphId;
|
|
|
|
GraphId cond_g = kInvalidGraphId;
|
|
|
|
if (utils::isa<AnfNodePtr>(c)) {
|
|
|
|
if (utils::isa<AnfNodePtr>(c)) {
|
|
|
|
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
|
|
|
|
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
|
|
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
|
|
|
|
MS_LOG(DEBUG) << "invoke set active:" << active_g;
|
|
|
|
MS_LOG(DEBUG) << "invoke set active:" << active_g;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
|
|
|
|
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
|
|
|
|
sess_->SetActive(active_g, cond_g);
|
|
|
|
target_sess_->SetActive(active_g, cond_g);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MsBackend::SetSwitchGraph() {
|
|
|
|
void MsBackend::SetSwitchGraph() {
|
|
|
@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
GraphId cond_g = kInvalidGraphId;
|
|
|
|
GraphId cond_g = kInvalidGraphId;
|
|
|
|
if (utils::isa<AnfNodePtr>(curr_switch_)) {
|
|
|
|
if (utils::isa<AnfNodePtr>(curr_switch_)) {
|
|
|
|
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
|
|
|
|
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
|
|
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
|
|
|
|
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
|
|
|
|
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
|
|
|
|
target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
is_switch_call_ = false;
|
|
|
|
is_switch_call_ = false;
|
|
|
|
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
|
|
|
|
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
|
|
|
@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
|
|
|
|
old_args[i] = args[it->second];
|
|
|
|
old_args[i] = args[it->second];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sess_->SetChildGraphInput(graph, old_args);
|
|
|
|
target_sess_->SetChildGraphInput(graph, old_args);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
graph_inputs_.erase(c);
|
|
|
|
graph_inputs_.erase(c);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef
|
|
|
|
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
|
|
|
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
|
|
|
MS_LOG(DEBUG) << "set graph input:" << g;
|
|
|
|
MS_LOG(DEBUG) << "set graph input:" << g;
|
|
|
|
// switch maybe twice
|
|
|
|
// switch maybe twice
|
|
|
|
sess_->SetChildGraphInput(g, args);
|
|
|
|
target_sess_->SetChildGraphInput(g, args);
|
|
|
|
|
|
|
|
|
|
|
|
if (is_switch_call_) {
|
|
|
|
if (is_switch_call_) {
|
|
|
|
if (!curr_switch_.is_null()) {
|
|
|
|
if (!curr_switch_.is_null()) {
|
|
|
@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
|
|
|
return VectorRef(outputs);
|
|
|
|
return VectorRef(outputs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
|
|
|
|
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
|
|
|
|
MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
|
|
|
|
MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
|
|
|
|
// Run graph
|
|
|
|
// Run graph
|
|
|
|
std::vector<tensor::TensorPtr> inputs;
|
|
|
|
std::vector<tensor::TensorPtr> inputs;
|
|
|
@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
|
|
|
|
|
|
|
|
|
|
|
|
VectorRef outputs;
|
|
|
|
VectorRef outputs;
|
|
|
|
// call ms rungraph (graphId, input ,output)
|
|
|
|
// call ms rungraph (graphId, input ,output)
|
|
|
|
sess_->RunGraph(g, inputs, &outputs);
|
|
|
|
if (target == kCPUDevice) {
|
|
|
|
|
|
|
|
cpu_sess_->RunGraph(g, inputs, &outputs);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
target_sess_->RunGraph(g, inputs, &outputs);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
|
|
|
|
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
|
|
|
|
return outputs;
|
|
|
|
return outputs;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
|
|
|
|
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
|
|
|
|
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
|
|
|
|
[](const AnfNodePtr &v) { return v; });
|
|
|
|
[](const AnfNodePtr &v) { return v; });
|
|
|
|
MS_LOG(DEBUG) << "Simulate start";
|
|
|
|
MS_LOG(DEBUG) << "Simulate start";
|
|
|
|
(void)sess_->SetFinalGraphInput(parameters);
|
|
|
|
(void)target_sess_->SetFinalGraphInput(parameters);
|
|
|
|
BaseRef output = rt->Eval(VectorRef(args));
|
|
|
|
BaseRef output = rt->Eval(VectorRef(args));
|
|
|
|
sess_->SetFinalGraphOutput(output);
|
|
|
|
target_sess_->SetFinalGraphOutput(output);
|
|
|
|
MS_LOG(DEBUG) << "Simulate Eval end";
|
|
|
|
MS_LOG(DEBUG) << "Simulate Eval end";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MsBackend::Link(GraphId graph_id) {
|
|
|
|
void MsBackend::Link(GraphId graph_id) {
|
|
|
|
if (graph_id == kInvalidGraphId) {
|
|
|
|
if (graph_id == kInvalidGraphId) {
|
|
|
|
graph_id = sess_->GetFinalRunGraph();
|
|
|
|
graph_id = target_sess_->GetFinalRunGraph();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sess_->BuildGraph(graph_id);
|
|
|
|
target_sess_->BuildGraph(graph_id);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Backend::Backend(const std::string &name) : name_(name) {
|
|
|
|
Backend::Backend(const std::string &name) : name_(name) {
|
|
|
@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
|
|
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
|
|
|
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1);
|
|
|
|
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
|
|
|
sess_ = session::SessionFactory::Get().Create(target);
|
|
|
|
target_sess_ = session::SessionFactory::Get().Create(target);
|
|
|
|
if (sess_ == nullptr) {
|
|
|
|
if (target_sess_ == nullptr) {
|
|
|
|
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
|
|
|
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sess_->Init(device_id);
|
|
|
|
target_sess_->Init(device_id);
|
|
|
|
sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
|
|
|
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
|
|
|
|
|
|
|
if (target == kCPUDevice) {
|
|
|
|
|
|
|
|
cpu_sess_ = target_sess_;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
|
|
|
|
|
|
|
|
if (cpu_sess_ == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << ".";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
cpu_sess_->Init(0);
|
|
|
|
|
|
|
|
cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }
|
|
|
|
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
|
|
|
|
|
|
|
|
|
|
|
|
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
|
|
|
|
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
|
|
|
|
|
|
|
|
|
|
|
|