You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
332 lines
12 KiB
332 lines
12 KiB
/**
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include "vm/backend.h"
|
|
|
|
#include <algorithm>
|
|
#include <vector>
|
|
|
|
#include "utils/log_adapter.h"
|
|
#include "ir/anf.h"
|
|
#include "utils/callbacks.h"
|
|
#include "utils/graph_utils.h"
|
|
#include "session/session_factory.h"
|
|
#include "common/utils.h"
|
|
#ifdef ENABLE_GE
|
|
#include "utils/callbacks_ge.h"
|
|
#endif
|
|
|
|
namespace mindspore {
|
|
namespace compile {
|
|
// Evaluates the condition reference `c` as a boolean.
// Both arguments are forwarded to BaseRefToBool and its result is returned unchanged.
bool Backend::GetCond(const BaseRef &c, bool *const value) {
  const bool converted = BaseRefToBool(c, value);
  return converted;
}
|
|
|
|
// Builds the convert result used when the sub-graphs have been merged into one big graph.
// The merged graph takes the parameters of `g` as its inputs and exposes a single placeholder
// output node. Its run function executes the session's final run graph through MsRunGraph.
// The result is stored in multi_result_ and a copy of it is returned.
LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
  MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();

  // The merged graph has its parameters at the beginning and only one output.
  multi_result_.inputs = g->parameters();
  final_output_ = NewValueNode("fake_output");
  multi_result_.outputs = {final_output_};

  GraphId final_graph_id = sess_->GetFinalRunGraph();
  auto run_final_graph = [this, final_graph_id](const VectorRef &args) -> VectorRef {
    return MsRunGraph(final_graph_id, args);
  };
  multi_result_.run = std::make_shared<RunFunc>(run_final_graph);
  return multi_result_;
}
|
|
|
|
// Converts a segment of ANF nodes into a LinConvertResult backed by a session-compiled graph.
//
// A segment that was converted before is served from g_ConvertCache. Otherwise the segment is
// transformed to obtain its inputs/outputs and compiled by the session. When the context is in
// precompile-only mode the result is returned right after compilation: it carries the inputs and
// outputs but keeps kInvalidGraphId, has no run functions, and is not cached. In the normal case
// the result gets `run` (MsRunGraph) and `simu_run` (MsSimuRunGraph) closures bound to the
// compiled graph id, and is recorded in both graph_id_map_ and g_ConvertCache.
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());

  // Fast path: this exact segment has already been converted.
  const auto cache_hit = g_ConvertCache.find(lst);
  if (cache_hit != g_ConvertCache.end()) {
    return cache_hit->second;
  }

  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);

  LinConvertResult result;
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;

  auto compiled_id = sess_->CompileGraph(lst, outputs);
  if (MsContext::GetInstance()->precompile_only()) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }

  auto real_run = [compiled_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(compiled_id, args); };
  result.run = std::make_shared<RunFunc>(real_run);
  MS_EXCEPTION_IF_NULL(result.run);

  auto simulated_run = [compiled_id, this](const VectorRef &args) -> VectorRef {
    return MsSimuRunGraph(compiled_id, args);
  };
  result.simu_run = std::make_shared<RunFunc>(simulated_run);
  MS_EXCEPTION_IF_NULL(result.simu_run);

  result.graph_id = compiled_id;

  // Remember the result both per graph id and per segment.
  graph_id_map_[compiled_id] = result;
  (void)g_ConvertCache.emplace(lst, result);
  return result;
}
|
|
|
|
void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
|
|
GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];
|
|
|
|
GraphId cond_g = kInvalidGraphId;
|
|
if (utils::isa<AnfNodePtr>(c)) {
|
|
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
|
|
} else {
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
|
|
}
|
|
auto before_cond = curr_switch_;
|
|
if (curr_switch_.hash() != c.hash()) {
|
|
// invoke while false->before true call
|
|
if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
|
|
active_g = simu_cond_map_[before_cond].cond_graph_map[false];
|
|
} else {
|
|
active_g = kInvalidGraphId;
|
|
}
|
|
// while x < y:
|
|
// z = y + 1
|
|
// while z < c2:
|
|
// out = out + 1
|
|
// z = z + 1
|
|
if (active_g == cond_g) {
|
|
active_g = kInvalidGraphId;
|
|
simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
|
|
}
|
|
MS_LOG(DEBUG) << "invoke set active:" << active_g;
|
|
}
|
|
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
|
|
sess_->SetActive(active_g, cond_g);
|
|
}
|
|
|
|
void MsBackend::SetSwitchGraph() {
|
|
MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
|
|
|
|
if (is_switch_call_) {
|
|
GraphId false_g = kInvalidGraphId;
|
|
GraphId true_g = kInvalidGraphId;
|
|
MS_LOG(DEBUG) << "start SetSwitchGraph";
|
|
true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
|
|
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
|
|
if (!curr_cond) {
|
|
if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
|
|
// has false branch
|
|
false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
|
|
}
|
|
GraphId cond_g = kInvalidGraphId;
|
|
if (utils::isa<AnfNodePtr>(curr_switch_)) {
|
|
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
|
|
} else {
|
|
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
|
|
}
|
|
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
|
|
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
|
|
}
|
|
is_switch_call_ = false;
|
|
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
|
|
}
|
|
}
|
|
|
|
// convert node from formal parameter to actual parameter,
|
|
// and actual parameter is graph user's formal parameter.
|
|
// get top while graph's parameter in recall while.
|
|
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
|
std::unordered_map<AnfNodePtr, size_t> params_index;
|
|
auto result = node;
|
|
auto graph = result->func_graph();
|
|
while (func_graph != graph) {
|
|
auto iter = graph_user_inputs_.find(graph);
|
|
if (iter == graph_user_inputs_.end()) {
|
|
break;
|
|
}
|
|
|
|
params_index.clear();
|
|
auto ¶ms = graph->parameters();
|
|
for (size_t i = 0; i < params.size(); ++i) {
|
|
params_index[params[i]] = i;
|
|
}
|
|
|
|
graph = iter->second.first;
|
|
auto &inputs = iter->second.second;
|
|
result = inputs[params_index[result]];
|
|
}
|
|
return result;
|
|
}
|
|
|
|
void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
|
|
const AnfNodePtrList &inputs) {
|
|
if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
|
|
return;
|
|
}
|
|
graph_user_inputs_[func_graph] = {user, inputs};
|
|
}
|
|
|
|
void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
|
|
std::unordered_map<AnfNodePtr, size_t> params_index;
|
|
auto ¶ms = func_graph->parameters();
|
|
for (size_t i = 0; i < params.size(); ++i) {
|
|
params_index[params[i]] = i;
|
|
}
|
|
|
|
// recall all child graphs in this while
|
|
auto &graph_inputs = graph_inputs_[c];
|
|
for (auto &iter : graph_inputs) {
|
|
auto &graph = iter.first;
|
|
auto &old_args = iter.second;
|
|
auto &result = graph_id_map_[graph];
|
|
auto &inputs = result.inputs;
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
auto input = ConvertGraphInput(func_graph, inputs[i]);
|
|
auto it = params_index.find(input);
|
|
if (it != params_index.end()) {
|
|
old_args[i] = args[it->second];
|
|
}
|
|
}
|
|
sess_->SetChildGraphInput(graph, old_args);
|
|
}
|
|
graph_inputs_.erase(c);
|
|
}
|
|
|
|
// compile set input output
|
|
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
|
MS_LOG(DEBUG) << "set graph input:" << g;
|
|
// switch maybe twice
|
|
sess_->SetChildGraphInput(g, args);
|
|
|
|
if (is_switch_call_) {
|
|
if (!curr_switch_.is_null()) {
|
|
// push this {g, args} to all user while graph_inputs for nest while,
|
|
// when current condition recall over delete this cond in graph_inputs.
|
|
for (auto &iter : graph_inputs_) {
|
|
iter.second.push_back({g, args});
|
|
}
|
|
if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
|
|
graph_inputs_[curr_switch_].push_back({g, args});
|
|
}
|
|
}
|
|
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
|
|
MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
|
|
simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
|
|
SetSwitchGraph();
|
|
}
|
|
|
|
std::vector<BaseRef> outputs;
|
|
(void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
|
|
[](const AnfNodePtr &v) { return v; });
|
|
return VectorRef(outputs);
|
|
}
|
|
|
|
// Runs graph `g` in the session with `args` and returns the outputs the session produced.
//
// Every argument is flattened into a list of tensors before the run:
//   - a tensor is used as is;
//   - a ValueTuple contributes each of its elements cast to a tensor;
//   - a Scalar is converted with ScalarToTensor;
//   - any other Value is cast to a tensor;
//   - a Python object is cast to a tensor through py::cast;
//   - a nested VectorRef contributes each of its elements cast to a tensor.
// Arguments of any other kind are skipped with a warning.
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;

  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    if (utils::isa<tensor::TensorPtr>(arg)) {
      inputs.push_back(utils::cast<tensor::TensorPtr>(arg));
      continue;
    }

    if (utils::isa<ValuePtr>(arg)) {
      auto value = utils::cast<ValuePtr>(arg);
      if (value->isa<ValueTuple>()) {
        const auto tuple = value->cast<ValueTuplePtr>();
        const auto &elements = tuple->value();
        (void)std::transform(elements.begin(), elements.end(), std::back_inserter(inputs),
                             [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
      } else if (value->isa<Scalar>()) {
        tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
        MS_EXCEPTION_IF_NULL(scalar_tensor);
        inputs.push_back(scalar_tensor);
      } else {
        inputs.push_back(value->cast<tensor::TensorPtr>());
      }
      continue;
    }

    if (utils::isa<PyObjectRef>(arg)) {
      auto py_object = utils::cast<PyObjectRef>(arg).object_;
      inputs.push_back(py::cast<tensor::TensorPtr>(py_object));
      continue;
    }

    if (utils::isa<VectorRefPtr>(arg)) {
      auto nested_args = utils::cast<VectorRef>(arg);
      (void)std::transform(nested_args.begin(), nested_args.end(), std::back_inserter(inputs),
                           [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); });
      continue;
    }

    MS_LOG(WARNING) << "Invalid input type.";
  }

  // call ms rungraph (graphId, input ,output)
  VectorRef outputs;
  sess_->RunGraph(g, inputs, &outputs);
  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}
|
|
|
|
// Records the simulated value of switch condition `c`.
//
// An entry for `c` is created on first use with `value` as its current condition. If a graph has
// already been recorded for branch `value`, kCondAlreadyRun is returned and the stored current
// condition is left as it is. Otherwise `value` becomes the current condition and kCondOk is
// returned.
SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();

  const bool first_seen = simu_cond_map_.find(c) == simu_cond_map_.end();
  if (first_seen) {
    CondGraph fresh_entry;
    fresh_entry.curr_cond = value;
    simu_cond_map_[c] = fresh_entry;
  }

  auto &cond_entry = simu_cond_map_[c];
  if (cond_entry.cond_graph_map.count(value)) {
    return kCondAlreadyRun;
  }
  cond_entry.curr_cond = value;
  MS_LOG(DEBUG) << "end set cond ";
  return kCondOk;
}
|
|
|
|
// Performs a simulated evaluation of `root` on the VM `rt`.
// The parameters of `root` are registered with the session as the final graph's inputs and are
// also passed to the VM as call arguments. The value the VM evaluates to is registered as the
// final graph's output.
void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
  MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();

  // The parameter nodes themselves stand in for the argument values.
  auto parameters = root->parameters();
  std::vector<BaseRef> args;
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
                       [](const AnfNodePtr &param) { return param; });

  MS_LOG(DEBUG) << "Simulate start";
  (void)sess_->SetFinalGraphInput(parameters);
  BaseRef output = rt->Eval(VectorRef(args));
  sess_->SetFinalGraphOutput(output);
  MS_LOG(DEBUG) << "Simulate Eval end";
}
|
|
|
|
// Asks the session to build the graph identified by `graph_id`.
// kInvalidGraphId selects the session's final run graph instead.
void MsBackend::Link(GraphId graph_id) {
  GraphId target_graph = graph_id;
  if (target_graph == kInvalidGraphId) {
    target_graph = sess_->GetFinalRunGraph();
  }
  sess_->BuildGraph(target_graph);
}
|
|
|
|
// Creates a backend named `name`.
// The convert function is taken from the `backends` table entry for that name, and the
// switch-call, multi-graph-sink and simulation flags all start out cleared.
Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "select backend:" << name;

  simu_flag_ = false;
  is_multi_graph_sink_ = false;
  is_switch_call_ = false;

  convert_fn_ = backends[name_];
}
|
|
|
|
// Creates a backend that compiles and runs graphs through a session for device `target`.
// The convert function is redirected to MsConvert of this object. The session is obtained from
// the session factory, initialised with `device_id`, and given the summary-save callback.
// MS_LOG(EXCEPTION) is raised when no session can be created for `target`.
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = [this](const AnfNodePtrList &lst) { return MsConvert(lst); };

  sess_ = session::SessionFactory::Get().Create(target);
  if (sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  }

  sess_->Init(device_id);
  sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
}
|
|
|
|
} // namespace compile
|
|
} // namespace mindspore
|