You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
665 lines
21 KiB
665 lines
21 KiB
/**
|
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
*
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "vm/transform.h"
|
|
|
|
#include <algorithm>
|
|
#include <map>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "pipeline/static_analysis/abstract_value.h"
|
|
#ifdef ENABLE_GE
|
|
#include "transform/convert.h"
|
|
#endif
|
|
#include "utils/graph_utils.h"
|
|
#include "utils/context/ms_context.h"
|
|
#include "debug/trace.h"
|
|
|
|
namespace mindspore {
|
|
namespace compile {
|
|
using mindspore::abstract::AbstractFunction;
|
|
using mindspore::abstract::AbstractFunctionPtr;
|
|
using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
|
|
using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
|
|
using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
|
|
|
|
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
|
|
prim::kPrimMakeTuple, prim::kPrimBpropCut};
|
|
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
|
|
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
|
|
prim::kPrimBpropCut};
|
|
return ms_nonlinear_ops;
|
|
}
|
|
|
|
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
|
|
: backend_(backend), cut_list_(cut_list) {
|
|
MS_EXCEPTION_IF_NULL(backend_);
|
|
lin_convert_ = backend_->convert_fn();
|
|
if (lin_convert_ == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
|
|
}
|
|
|
|
is_gevm_convert_ = false;
|
|
if (backend->name() == kGeVm) {
|
|
MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
|
|
is_gevm_convert_ = true;
|
|
}
|
|
}
|
|
|
|
bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
if (node->isa<CNode>()) {
|
|
auto cnode = node->cast<CNodePtr>();
|
|
auto &inputs = cnode->inputs();
|
|
if (inputs.empty()) {
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
|
}
|
|
|
|
AnfNodePtr fn = inputs[0];
|
|
if (!IsValueNode<Primitive>(fn)) {
|
|
return true;
|
|
}
|
|
|
|
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
|
|
for (auto &prim : cut_list_) {
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
if (prim->name() == node_prim->name()) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
#ifdef ENABLE_GE
|
|
if (is_gevm_convert_) {
|
|
auto name = GetCNodeFuncName(cnode);
|
|
auto adpt = transform::DfGraphConvertor::FindAdapter(name);
|
|
if (adpt == nullptr) {
|
|
return true;
|
|
}
|
|
}
|
|
#endif
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
VectorRef splits;
|
|
VectorRef split;
|
|
std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return());
|
|
|
|
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
|
for (auto &node : nodes) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
if (IsCut(node)) {
|
|
MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size();
|
|
if (split.size() != 0) {
|
|
splits.push_back(split);
|
|
}
|
|
splits.push_back(node);
|
|
split.clear();
|
|
} else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) {
|
|
split.push_back(node);
|
|
MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
|
|
}
|
|
}
|
|
MS_LOG(DEBUG) << "Split node size :" << splits.size();
|
|
return splits;
|
|
}
|
|
|
|
// Push the value node on the stack.
|
|
void CompileGraph::Push(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
if (slots_.count(node) > 0) {
|
|
MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString()
|
|
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
|
}
|
|
MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
|
|
<< " is parameter: " << node->isa<Parameter>();
|
|
slots_[node] = height_;
|
|
set_height(height_ + 1);
|
|
}
|
|
|
|
void CompileGraph::AddInst(const Instruction &inst, const int &arg) {
|
|
VectorRef args;
|
|
args.push_back(arg);
|
|
AddInst(inst, args);
|
|
}
|
|
|
|
void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
|
|
VectorRef args;
|
|
args.push_back(arg);
|
|
AddInst(inst, args);
|
|
}
|
|
|
|
void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
|
|
inst_.push_back(std::make_pair(inst, args));
|
|
}
|
|
|
|
// Gets the stack reference for the node value. If the node is a constant,
|
|
// it may actually cause the push in to not be mentioned before.
|
|
int CompileGraph::Ref(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
|
|
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
|
|
if (IsValueNode<FuncGraph>(node)) {
|
|
MS_LOG(DEBUG) << "Push graph.";
|
|
AddInst(Instruction::kGraph, GetValueNode(node));
|
|
} else {
|
|
MS_LOG(DEBUG) << "Push.";
|
|
if (IsValueNode<Primitive>(node)) {
|
|
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
|
} else {
|
|
AddInst(Instruction::kPush, GetValueNode(node));
|
|
}
|
|
}
|
|
Push(node);
|
|
}
|
|
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
|
|
<< ", return: " << slots_[node] - height_;
|
|
return slots_[node] - height_;
|
|
}
|
|
|
|
// Make sure the value of node is at the top of the stack.
|
|
void CompileGraph::AddInput(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
if (slots_.count(node) == 0) {
|
|
MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
|
|
(void)Ref(node);
|
|
return;
|
|
}
|
|
AddInst(Instruction::kInput, Ref(node));
|
|
set_height(height_ + 1);
|
|
}
|
|
|
|
// Call back effect in stack
|
|
void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); }
|
|
|
|
void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
std::vector<AnfNodePtr> parameters = graph->parameters();
|
|
for (size_t i = parameters.size(); i != 0; i--) {
|
|
Push(parameters[i - 1]);
|
|
MS_LOG(DEBUG) << "Push parameter " << i - 1 << ": " << parameters[i - 1]->DebugString(true);
|
|
}
|
|
}
|
|
|
|
int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) {
|
|
MS_LOG(DEBUG) << "LinConvert start";
|
|
LinConvertResult result;
|
|
|
|
if (backend_->simu_flag()) {
|
|
result = backend_->GetMultiGraphRun(graph);
|
|
} else {
|
|
result = lin_convert_(node_list);
|
|
}
|
|
|
|
if (result.run == nullptr) {
|
|
MS_LOG(ERROR) << "LinConvert failed";
|
|
return RET_FAILED;
|
|
}
|
|
|
|
if (!(*result.run)) {
|
|
if (result.inputs.size() != result.outputs.size()) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
|
|
} else {
|
|
size_t size = result.inputs.size();
|
|
for (size_t i = 0; i < size; i++) {
|
|
Tie(result.inputs[i], result.outputs[i]);
|
|
}
|
|
return RET_CONTINUE;
|
|
}
|
|
}
|
|
AddExternal(result);
|
|
for (auto &o : result.outputs) {
|
|
Push(o);
|
|
}
|
|
|
|
return RET_SUCCESS;
|
|
}
|
|
|
|
void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
|
|
MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
|
|
if (backend_->is_multi_graph_sink()) {
|
|
VectorRef args;
|
|
args.emplace_back(-1);
|
|
MS_LOG(DEBUG) << "call::" << height_;
|
|
AddInst(Instruction::kCall, args);
|
|
|
|
args.clear();
|
|
args.emplace_back(node->input(1));
|
|
AddInst(Instruction::kSwitchReturn, args);
|
|
|
|
args.clear();
|
|
args.emplace_back(false);
|
|
args.emplace_back(Ref(node->input(1)));
|
|
args.emplace_back(Ref(node->input(2)));
|
|
args.emplace_back(Ref(node->input(3)));
|
|
AddInst(Instruction::kSwitch, args);
|
|
}
|
|
}
|
|
|
|
int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
|
|
std::vector<AnfNodePtr> node_inputs = node->inputs();
|
|
if (node_inputs.empty()) {
|
|
MS_LOG(EXCEPTION) << "The node->inputs() is empty";
|
|
}
|
|
AnfNodePtr fn = node_inputs[0];
|
|
if (IsValueNode<Primitive>(fn)) {
|
|
PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
|
|
MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
|
|
for (size_t i = node_inputs.size() - 1; i > 0; i--) {
|
|
AddInput(node->input(i));
|
|
}
|
|
if (IsPrimitive(fn, prim::kPrimReturn)) {
|
|
AddReturn(node);
|
|
return RET_BREAK;
|
|
}
|
|
if (IsPrimitive(fn, prim::kPrimPartial)) {
|
|
AddPartial(node);
|
|
} else if (IsPrimitive(fn, prim::kPrimSwitch)) {
|
|
AddSwitch(node);
|
|
AddSinkSwitch(node);
|
|
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
|
|
AddMakeTuple(node);
|
|
} else {
|
|
AddPrimitive(node, value);
|
|
}
|
|
} else {
|
|
int ret = AddCall(graph, node);
|
|
if (ret == RET_BREAK) {
|
|
return ret;
|
|
}
|
|
}
|
|
Push(node);
|
|
return RET_SUCCESS;
|
|
}
|
|
|
|
void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
|
|
auto ret = LinConvert(graph, {});
|
|
if (ret == RET_FAILED) {
|
|
MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
|
|
}
|
|
AddReturn(nullptr);
|
|
}
|
|
|
|
bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
|
|
MS_LOG(DEBUG) << "Start split graph";
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
VectorRef splits = SplitNodes(graph);
|
|
|
|
MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
|
|
for (auto &split : splits) {
|
|
int ret = RET_SUCCESS;
|
|
if (utils::isa<VectorRef>(split)) {
|
|
MS_LOG(DEBUG) << "Start a extern LinConvert";
|
|
std::vector<AnfNodePtr> args;
|
|
auto vec_ref = utils::cast<VectorRef>(split);
|
|
(void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
|
|
[](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
|
|
ret = LinConvert(graph, args);
|
|
MS_LOG(DEBUG) << "End a extern LinConvert";
|
|
if (ret == RET_FAILED) {
|
|
return false;
|
|
}
|
|
if (ret == RET_CONTINUE) {
|
|
continue;
|
|
}
|
|
} else {
|
|
MS_LOG(DEBUG) << "Start a cut node";
|
|
if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
|
|
MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
|
|
}
|
|
CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
|
|
ret = InterpretNode(graph, node);
|
|
MS_LOG(DEBUG) << "End a cut node";
|
|
if (ret == RET_BREAK) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
MS_LOG(DEBUG) << "End split graph";
|
|
return true;
|
|
}
|
|
|
|
InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
|
|
InstSet inst = Run(graph);
|
|
return inst;
|
|
}
|
|
|
|
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
MS_LOG(DEBUG) << "Compile start graph: " << graph->ToString();
|
|
|
|
Reset();
|
|
PushParameters(graph);
|
|
int param_height = height_;
|
|
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
|
|
|
|
if (backend_->simu_flag()) {
|
|
GenMultiGraphsRun(graph);
|
|
} else {
|
|
if (!SplitGraph(graph)) {
|
|
return inst_;
|
|
}
|
|
}
|
|
|
|
AddPadStack(param_height);
|
|
auto ret = inst_;
|
|
Reset();
|
|
return ret;
|
|
}
|
|
|
|
void CompileGraph::AddPadStack(int param_height) {
|
|
int stack_sizes = max_height_ - param_height;
|
|
MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
|
|
<< " need_stack:" << stack_sizes;
|
|
if (stack_sizes > 0) {
|
|
VectorRef need_stacks({stack_sizes});
|
|
(void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
|
|
}
|
|
}
|
|
|
|
void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
|
|
VectorRef args;
|
|
args.emplace_back(Ref(fn));
|
|
args.emplace_back(height_);
|
|
args.emplace_back(static_cast<int>(size - 1));
|
|
MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
|
|
AddInst(Instruction::kTailCall, args);
|
|
}
|
|
|
|
void CompileGraph::AddPartial(const CNodePtr &node) {
|
|
auto inputs = node->inputs();
|
|
VectorRef args;
|
|
auto fn = inputs[1];
|
|
if (!IsValueNode<FuncGraph>(fn)) {
|
|
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
|
|
}
|
|
if (backend_->is_multi_graph_sink()) {
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
|
|
args.emplace_back(func_graph);
|
|
AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
|
|
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
|
|
}
|
|
for (size_t i = 1; i < inputs.size(); i++) {
|
|
args.emplace_back(Ref(inputs[i]));
|
|
}
|
|
AddInst(Instruction::kPartial, args);
|
|
}
|
|
|
|
void CompileGraph::AddMakeTuple(const CNodePtr &node) {
|
|
auto inputs = node->inputs();
|
|
VectorRef args;
|
|
for (size_t i = 1; i < inputs.size(); i++) {
|
|
args.emplace_back(Ref(inputs[i]));
|
|
}
|
|
AddInst(Instruction::kTuple, args);
|
|
}
|
|
|
|
void CompileGraph::AddSwitch(const CNodePtr &node) {
|
|
auto inputs = node->inputs();
|
|
if (inputs.size() < 4) {
|
|
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
|
}
|
|
VectorRef args;
|
|
if (backend_->is_multi_graph_sink()) {
|
|
args.emplace_back(true);
|
|
}
|
|
args.emplace_back(Ref(inputs[1]));
|
|
args.emplace_back(Ref(inputs[2]));
|
|
args.emplace_back(Ref(inputs[3]));
|
|
AddInst(Instruction::kSwitch, args);
|
|
}
|
|
|
|
void CompileGraph::AddReturn(const CNodePtr &node) {
|
|
VectorRef args;
|
|
if (backend_->simu_flag()) {
|
|
args.emplace_back(Ref(backend_->final_output()));
|
|
} else {
|
|
args.emplace_back(Ref(node->input(1)));
|
|
}
|
|
args.emplace_back(height_);
|
|
AddInst(Instruction::kReturn, args);
|
|
}
|
|
|
|
void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
|
|
auto inputs = node->inputs();
|
|
VectorRef args;
|
|
args.push_back(prim);
|
|
for (size_t i = 1; i < inputs.size(); i++) {
|
|
args.emplace_back(Ref(inputs[i]));
|
|
}
|
|
AddInst(Instruction::kPrim, args);
|
|
}
|
|
|
|
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
|
|
auto inputs = node->inputs();
|
|
AnfNodePtr fn = inputs[0];
|
|
if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
|
|
AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
|
|
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
|
|
}
|
|
(void)Ref(fn);
|
|
size_t size = inputs.size();
|
|
for (size_t i = size - 1; i > 0; i--) {
|
|
AddInput(inputs[i]);
|
|
}
|
|
if (node == graph->output()) {
|
|
AddTailCall(fn, size);
|
|
return RET_BREAK;
|
|
}
|
|
MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << size - 1;
|
|
AddInst(Instruction::kCall, Ref(fn));
|
|
Ret(static_cast<int>(size - 1));
|
|
return RET_SUCCESS;
|
|
}
|
|
|
|
void CompileGraph::AddExternal(const LinConvertResult &result) {
|
|
VectorRef args;
|
|
args.push_back(result.run);
|
|
args.push_back(result.simu_run);
|
|
size_t size = result.inputs.size();
|
|
for (size_t i = 0; i < size; i++) {
|
|
args.emplace_back(Ref(result.inputs[i]));
|
|
}
|
|
AddInst(Instruction::kExternal, args);
|
|
}
|
|
|
|
void TraverseGraphMap(
|
|
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
|
|
const FuncGraphSet &fgs,
|
|
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
|
|
MS_EXCEPTION_IF_NULL(manager_ptr);
|
|
MS_EXCEPTION_IF_NULL(tr);
|
|
for (const auto &fg : fgs) {
|
|
for (const auto &ct_any : fg->value_nodes()) {
|
|
AnfNodePtr const_primitive_node = ct_any.first;
|
|
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
|
|
auto users = manager_ptr->node_users()[const_primitive_node];
|
|
for (auto &use : users) {
|
|
CNodePtr node = use.first->cast<CNodePtr>();
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
int key = use.second;
|
|
if (key != 0) {
|
|
MS_EXCEPTION_IF_NULL(node->input(0));
|
|
bool key_is_const = node->input(0)->isa<ValueNode>();
|
|
PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
|
|
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
|
|
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
|
|
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
|
|
continue;
|
|
}
|
|
FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
|
|
dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
|
|
tr->SetEdge(node, key, NewValueNode(g));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
FuncGraphManagerPtr manager_ptr = graph->manager();
|
|
MS_EXCEPTION_IF_NULL(manager_ptr);
|
|
MapPrimTypeFuncGraph prim_graphs;
|
|
auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
|
|
PrimTypePair prim_type = std::make_pair(prim, type);
|
|
if (prim_graphs.end() == prim_graphs.find(prim_type)) {
|
|
FuncGraphPtr g = std::make_shared<FuncGraph>();
|
|
std::vector<AnfNodePtr> args;
|
|
ValueNodePtr prim_ct = NewValueNode(prim);
|
|
MS_EXCEPTION_IF_NULL(prim_ct);
|
|
prim_ct->set_abstract(type);
|
|
args.push_back(prim_ct);
|
|
MS_EXCEPTION_IF_NULL(type);
|
|
TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
|
|
MS_EXCEPTION_IF_NULL(tp);
|
|
MS_EXCEPTION_IF_NULL(g);
|
|
for (auto t : tp->args_spec_list()) {
|
|
ParameterPtr p = g->add_parameter();
|
|
p->set_abstract(t);
|
|
args.push_back(p);
|
|
}
|
|
AnfNodePtr out = g->NewCNode(args);
|
|
out->set_abstract(tp->output());
|
|
g->set_output(out);
|
|
prim_graphs[prim_type] = g;
|
|
}
|
|
|
|
return prim_graphs[prim_type];
|
|
};
|
|
|
|
FuncGraphTransaction tr = manager_ptr->Transact();
|
|
auto &fgs = manager_ptr->func_graphs();
|
|
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
|
|
|
|
return graph;
|
|
}
|
|
|
|
CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
|
|
MS_EXCEPTION_IF_NULL(backend);
|
|
MS_LOG(DEBUG) << "Start vm: " << backend->name();
|
|
transform_ = std::make_shared<CompileGraph>(backend, cut_list);
|
|
Reset();
|
|
}
|
|
|
|
// Convert graphs to unlinked instructions.
|
|
void CompileGraphs::Compile(const FuncGraphPtr &graph) {
|
|
MS_LOG(DEBUG) << "Start";
|
|
auto graph_manager = graph->manager();
|
|
MS_EXCEPTION_IF_NULL(graph_manager);
|
|
FuncGraphSet graphs = graph_manager->func_graphs();
|
|
for (auto &g : graphs) {
|
|
mapping_[g] = static_cast<int>(insts_.size());
|
|
if (transform_ != nullptr) {
|
|
InstSet insts = transform_->Run(g);
|
|
if (!insts.empty()) {
|
|
(void)insts_.insert(insts_.end(), insts.begin(), insts.end());
|
|
}
|
|
}
|
|
}
|
|
MS_LOG(DEBUG) << "End";
|
|
}
|
|
|
|
// Link instructions from multiple function graphs together.
|
|
FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
|
|
MS_LOG(DEBUG) << "Start";
|
|
for (std::size_t i = 0; i < insts_.size(); i++) {
|
|
InstType inst = insts_[i];
|
|
MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
|
|
if (Instruction::kGraph == inst.first) {
|
|
if (inst.second.empty()) {
|
|
MS_LOG(EXCEPTION) << "The second element of inst is empty";
|
|
}
|
|
FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
|
|
MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
|
|
insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
|
|
}
|
|
}
|
|
|
|
FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
|
|
if (backend_->is_multi_graph_sink()) {
|
|
backend_->set_simu_flag(true);
|
|
MS_LOG(DEBUG) << "Start simulate";
|
|
backend_->SimulateRun(rt, graph);
|
|
MS_LOG(DEBUG) << "Link graphs";
|
|
insts_ = transform_->GenMultiGraphsSinkInst(graph);
|
|
rt->set_insts(insts_);
|
|
backend_->set_simu_flag(false);
|
|
MS_LOG(DEBUG) << "End start simulate";
|
|
backend_->Link(kInvalidGraphId);
|
|
}
|
|
MS_LOG(DEBUG) << "End";
|
|
return rt;
|
|
}
|
|
|
|
// Convert all graphs to unlinked instructions and link them.
|
|
FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
MS_LOG(DEBUG) << "Start";
|
|
Reset();
|
|
MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
|
|
|
|
(void)WrapPrimitives(graph);
|
|
Compile(graph);
|
|
|
|
FinalVMPtr rt = Link(graph);
|
|
Reset();
|
|
MS_LOG(DEBUG) << "End";
|
|
return rt;
|
|
}
|
|
|
|
BackendPtr CreateBackend() {
|
|
auto context_ptr = MsContext::GetInstance();
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
std::string name = context_ptr->backend_policy();
|
|
MS_LOG(INFO) << "CreateBackend is: " << name;
|
|
if (backend_list.count(name) == 0) {
|
|
MS_LOG(EXCEPTION) << "Backend is error: " << name;
|
|
}
|
|
|
|
if (name == kMsConvert) {
|
|
std::string target = context_ptr->device_target();
|
|
uint32_t device_id = context_ptr->device_id();
|
|
auto backend = std::make_shared<MsBackend>(name, target, device_id);
|
|
std::string device_target = MsContext::GetInstance()->device_target();
|
|
if (device_target == kAscendDevice) {
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
|
backend->set_is_multi_graph_sink(false);
|
|
context_ptr->set_is_multi_graph_sink(false);
|
|
} else {
|
|
backend->set_is_multi_graph_sink(true);
|
|
context_ptr->set_is_multi_graph_sink(true);
|
|
}
|
|
}
|
|
return backend;
|
|
}
|
|
|
|
return std::make_shared<Backend>(name);
|
|
}
|
|
} // namespace compile
|
|
} // namespace mindspore
|