You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/vm/backend.cc

332 lines
12 KiB

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "vm/backend.h"
#include <algorithm>
#include <vector>
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "utils/callbacks.h"
#include "utils/graph_utils.h"
#include "session/session_factory.h"
#include "common/utils.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif
namespace mindspore {
namespace compile {
// Forwards the condition reference `c` and the output pointer `value` to
// BaseRefToBool and returns that call's result unchanged.
bool Backend::GetCond(const BaseRef &c, bool *const value) {
  const bool converted = BaseRefToBool(c, value);
  return converted;
}
// Fills `multi_result_` for the merged multi-graph case and returns a copy of it.
// The inputs are the parameters of `g`, the only output is a value node named
// "fake_output" (also stored in `final_output_`), and the run function calls
// MsRunGraph on the session's final run graph id, which is read once here.
LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
  // multi_graph merge to one, big graph have parameters in begin and only have one output
  MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();

  final_output_ = NewValueNode("fake_output");
  multi_result_.inputs = g->parameters();
  multi_result_.outputs = {final_output_};

  const GraphId final_graph_id = sess_->GetFinalRunGraph();
  auto run_final_graph = [final_graph_id, this](const VectorRef &args) -> VectorRef {
    return MsRunGraph(final_graph_id, args);
  };
  multi_result_.run = std::make_shared<RunFunc>(run_final_graph);
  return multi_result_;
}
// Compiles the node segment `lst` through the session and wraps it in a
// LinConvertResult holding the segment inputs/outputs plus `run` / `simu_run`
// callables bound to the compiled graph id.
//
// Results are looked up in and stored into g_ConvertCache, keyed by `lst`.
// When the context reports precompile_only(), the graph is still compiled but the
// function returns early: no callables are attached, graph_id stays
// kInvalidGraphId, and nothing is written to graph_id_map_ or the cache.
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());

  // Reuse an earlier conversion of the same segment when there is one.
  const auto cache_hit = g_ConvertCache.find(lst);
  if (cache_hit != g_ConvertCache.end()) {
    return cache_hit->second;
  }

  FuncGraphPtr fg;
  AnfNodePtrList segment_inputs;
  AnfNodePtrList segment_outputs;
  std::tie(fg, segment_inputs, segment_outputs) = TransformSegmentToAnfGraph(lst);

  LinConvertResult result;
  result.graph_id = kInvalidGraphId;
  result.inputs = segment_inputs;
  result.outputs = segment_outputs;

  auto graph_id = sess_->CompileGraph(lst, segment_outputs);
  if (MsContext::GetInstance()->precompile_only()) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }

  // Real execution path.
  auto run_graph = [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); };
  result.run = std::make_shared<RunFunc>(run_graph);
  MS_EXCEPTION_IF_NULL(result.run);

  // Simulated execution path.
  auto simu_run_graph = [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); };
  result.simu_run = std::make_shared<RunFunc>(simu_run_graph);
  MS_EXCEPTION_IF_NULL(result.simu_run);

  result.graph_id = graph_id;
  graph_id_map_[graph_id] = result;
  (void)g_ConvertCache.emplace(lst, result);
  return result;
}
// Tells the session which graph becomes active for condition `c`.
//
// The active graph starts as the graph recorded for branch `cond` of `c`. If `c`
// is not the switch currently tracked in `curr_switch_` (compared by hash), the
// active graph is instead the false-branch graph recorded for `curr_switch_`, or
// kInvalidGraphId when none is recorded. If that graph equals the condition's own
// graph, both the active graph and the recorded false branch are reset to
// kInvalidGraphId.
//
// Raises through MS_LOG(EXCEPTION) when `c` does not hold an AnfNodePtr.
void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
  // Note: operator[] creates the entries for `c` / `cond` when they are missing.
  GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];

  GraphId cond_g = kInvalidGraphId;
  if (!utils::isa<AnfNodePtr>(c)) {
    MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
  } else {
    cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
  }

  auto before_cond = curr_switch_;
  const bool switch_changed = curr_switch_.hash() != c.hash();
  if (switch_changed) {
    // invoke while false->before true call
    auto &before_graphs = simu_cond_map_[before_cond].cond_graph_map;
    const auto false_branch = before_graphs.find(false);
    if (false_branch != before_graphs.end()) {
      active_g = false_branch->second;
    } else {
      active_g = kInvalidGraphId;
    }
    // while x < y:
    //   z = y + 1
    //   while z < c2:
    //     out = out + 1
    //     z = z + 1
    if (active_g == cond_g) {
      before_graphs[false] = kInvalidGraphId;
      active_g = kInvalidGraphId;
    }
    MS_LOG(DEBUG) << "invoke set active:" << active_g;
  }

  MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
  sess_->SetActive(active_g, cond_g);
}
void MsBackend::SetSwitchGraph() {
MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
if (is_switch_call_) {
GraphId false_g = kInvalidGraphId;
GraphId true_g = kInvalidGraphId;
MS_LOG(DEBUG) << "start SetSwitchGraph";
true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
if (!curr_cond) {
if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
// has false branch
false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
}
GraphId cond_g = kInvalidGraphId;
if (utils::isa<AnfNodePtr>(curr_switch_)) {
cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
} else {
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
}
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
}
is_switch_call_ = false;
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
}
}
// Convert a node from formal parameter to actual parameter, where the actual
// parameter is in turn a formal parameter of the graph's user. Used to reach the
// top while graph's parameter when recalling a while.
//
// Starting from the graph that owns `node`, repeatedly replaces the node by the
// actual input its user passes at the same parameter position, following
// `graph_user_inputs_`, until `func_graph` is reached.
//
// The walk stops early and returns the node reached so far when:
//   - no user is recorded for the current graph,
//   - the current node is not one of the current graph's formal parameters, or
//   - the user recorded fewer actual inputs than that parameter's position.
// The last two cases previously fell back to position 0 (via a default-inserting
// map lookup) or indexed past the end of the input list.
//
// Raises through MS_EXCEPTION_IF_NULL when `node` is null.
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto result = node;
  auto graph = result->func_graph();
  while (func_graph != graph) {
    const auto iter = graph_user_inputs_.find(graph);
    if (iter == graph_user_inputs_.end()) {
      break;
    }

    // Locate the formal-parameter position of the current node.
    const auto &params = graph->parameters();
    const auto pos = std::find(params.begin(), params.end(), result);
    if (pos == params.end()) {
      // Not a formal parameter: there is no actual argument to map it to.
      break;
    }
    const auto index = static_cast<size_t>(pos - params.begin());

    const auto &inputs = iter->second.second;
    if (index >= inputs.size()) {
      // The user did not record an actual input for this position.
      break;
    }

    graph = iter->second.first;
    result = inputs[index];
  }
  return result;
}
// Records, for `func_graph`, the graph that uses it (`user`) together with the
// actual inputs that user passes. Only the first registration for a given
// `func_graph` is kept; later calls leave the existing entry untouched.
void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
                                   const AnfNodePtrList &inputs) {
  const bool already_recorded = graph_user_inputs_.count(func_graph) != 0;
  if (!already_recorded) {
    graph_user_inputs_[func_graph] = {user, inputs};
  }
}
void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
std::unordered_map<AnfNodePtr, size_t> params_index;
auto &params = func_graph->parameters();
for (size_t i = 0; i < params.size(); ++i) {
params_index[params[i]] = i;
}
// recall all child graphs in this while
auto &graph_inputs = graph_inputs_[c];
for (auto &iter : graph_inputs) {
auto &graph = iter.first;
auto &old_args = iter.second;
auto &result = graph_id_map_[graph];
auto &inputs = result.inputs;
for (size_t i = 0; i < inputs.size(); ++i) {
auto input = ConvertGraphInput(func_graph, inputs[i]);
auto it = params_index.find(input);
if (it != params_index.end()) {
old_args[i] = args[it->second];
}
}
sess_->SetChildGraphInput(graph, old_args);
}
graph_inputs_.erase(c);
}
// Simulated run of graph `g` at compile time: sets the graph's inputs and outputs
// without executing it.
//
// The arguments are passed to the session with SetChildGraphInput. While a switch
// call is in progress (`is_switch_call_`), the (graph, args) pair is also recorded
// for later recall, `g` is stored as the graph of the current branch of
// `curr_switch_`, and SetSwitchGraph() is invoked.
//
// Returns the output nodes recorded for `g` in `graph_id_map_`, wrapped in a
// VectorRef.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "set graph input:" << g;
  // switch maybe twice
  sess_->SetChildGraphInput(g, args);

  if (is_switch_call_) {
    if (!curr_switch_.is_null()) {
      // push this {g, args} to all user while graph_inputs for nest while,
      // when current condition recall over delete this cond in graph_inputs.
      for (auto &pending : graph_inputs_) {
        pending.second.push_back({g, args});
      }
      const bool switch_recorded = graph_inputs_.find(curr_switch_) != graph_inputs_.end();
      if (!switch_recorded) {
        graph_inputs_[curr_switch_].push_back({g, args});
      }
    }

    auto &cond_state = simu_cond_map_[curr_switch_];
    bool curr_cond = cond_state.curr_cond;
    MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
    cond_state.cond_graph_map[curr_cond] = g;
    SetSwitchGraph();
  }

  // Hand back the graph's output nodes as generic references.
  const auto &graph_outputs = graph_id_map_[g].outputs;
  std::vector<BaseRef> outputs;
  (void)std::transform(graph_outputs.begin(), graph_outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &out_node) { return out_node; });
  return VectorRef(outputs);
}
// Runs graph `g` in the session and returns its outputs.
//
// Each entry of `args` is flattened into the tensor list passed to RunGraph:
//   - a tensor is appended as is;
//   - a ValueTuple contributes one entry per element (each element cast to a tensor);
//   - a Scalar is converted with ScalarToTensor (null result raises);
//   - any other value is cast to a tensor;
//   - a PyObjectRef is converted with py::cast;
//   - a VectorRef contributes one entry per element (each element cast to a tensor);
//   - anything else is skipped after logging a warning.
// Cast results are appended without a null check.
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;

  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    if (utils::isa<tensor::TensorPtr>(arg)) {
      inputs.push_back(utils::cast<tensor::TensorPtr>(arg));
      continue;
    }

    if (utils::isa<ValuePtr>(arg)) {
      auto value = utils::cast<ValuePtr>(arg);
      if (value->isa<ValueTuple>()) {
        auto value_tuple = value->cast<ValueTuplePtr>();
        for (const ValuePtr &element : value_tuple->value()) {
          inputs.push_back(element->cast<tensor::TensorPtr>());
        }
      } else if (value->isa<Scalar>()) {
        tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
        MS_EXCEPTION_IF_NULL(scalar_tensor);
        inputs.push_back(scalar_tensor);
      } else {
        inputs.push_back(value->cast<tensor::TensorPtr>());
      }
      continue;
    }

    if (utils::isa<PyObjectRef>(arg)) {
      auto py_object = utils::cast<PyObjectRef>(arg).object_;
      inputs.push_back(py::cast<tensor::TensorPtr>(py_object));
      continue;
    }

    if (utils::isa<VectorRefPtr>(arg)) {
      auto nested_args = utils::cast<VectorRef>(arg);
      for (const BaseRef &nested : nested_args) {
        inputs.push_back(utils::cast<tensor::TensorPtr>(nested));
      }
      continue;
    }

    MS_LOG(WARNING) << "Invalid input type.";
  }

  // call ms rungraph (graphId, input ,output)
  VectorRef outputs;
  sess_->RunGraph(g, inputs, &outputs);
  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}
// Records `value` as the current simulated value of condition `c`.
//
// An entry for `c` is created on first use. Returns kCondAlreadyRun, leaving the
// entry's current value unchanged, when a graph is already recorded for branch
// `value`; otherwise stores `value` as the current condition and returns kCondOk.
SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();

  const bool first_seen = simu_cond_map_.find(c) == simu_cond_map_.end();
  if (first_seen) {
    CondGraph fresh_state;
    fresh_state.curr_cond = value;
    simu_cond_map_[c] = fresh_state;
  }

  auto &cond_state = simu_cond_map_[c];
  if (cond_state.cond_graph_map.count(value) != 0) {
    return kCondAlreadyRun;
  }
  cond_state.curr_cond = value;
  MS_LOG(DEBUG) << "end set cond ";
  return kCondOk;
}
// Drives a simulated evaluation of `root` on the VM `rt`.
//
// The root graph's parameters are registered with the session as the final graph
// inputs and are also passed, as generic references, to rt->Eval. The value Eval
// returns is registered with the session as the final graph output.
void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
  MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();

  // The parameter nodes themselves stand in for the argument values.
  auto root_params = root->parameters();
  std::vector<BaseRef> vm_args;
  (void)std::transform(root_params.begin(), root_params.end(), std::back_inserter(vm_args),
                       [](const AnfNodePtr &param) { return param; });

  MS_LOG(DEBUG) << "Simulate start";
  (void)sess_->SetFinalGraphInput(root_params);
  BaseRef eval_result = rt->Eval(VectorRef(vm_args));
  sess_->SetFinalGraphOutput(eval_result);
  MS_LOG(DEBUG) << "Simulate Eval end";
}
// Asks the session to build graph `graph_id`. Passing kInvalidGraphId selects the
// session's final run graph instead.
void MsBackend::Link(GraphId graph_id) {
  GraphId target_graph = (graph_id == kInvalidGraphId) ? sess_->GetFinalRunGraph() : graph_id;
  sess_->BuildGraph(target_graph);
}
// Creates a backend called `name`: clears the switch-call, multi-graph-sink and
// simulation flags and takes the convert function stored under that name in
// `backends`.
Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "select backend:" << name;

  // Every mode flag starts cleared.
  simu_flag_ = false;
  is_multi_graph_sink_ = false;
  is_switch_call_ = false;

  // NOTE(review): `backends` is indexed with operator[]; if it is a standard map,
  // an unknown name would add a default entry rather than fail - confirm.
  convert_fn_ = backends[name_];
}
// Creates a session-backed backend.
//
// The convert function is replaced by one that calls MsConvert on this object. A
// session for `target` is obtained from the session factory, initialised with
// `device_id`, and given the summary-save callback.
//
// Raises through MS_LOG(EXCEPTION) when the factory returns no session.
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = [this](const AnfNodePtrList &lst) { return this->MsConvert(lst); };

  sess_ = session::SessionFactory::Get().Create(target);
  if (!sess_) {
    MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  }

  sess_->Init(device_id);
  sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
}
} // namespace compile
} // namespace mindspore