|
|
|
@ -521,6 +521,73 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
|
|
|
|
|
return rt;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GraphCompiler::GraphCompiler(const std::shared_ptr<MindRTBackend> &backend, const std::vector<PrimitivePtr> &cut_list)
|
|
|
|
|
: backend_(backend) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(backend_);
|
|
|
|
|
if (backend_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "The backend isn't created.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
FuncGraphPtr root_graph = WrapPrimitives(func_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(root_graph);
|
|
|
|
|
|
|
|
|
|
// Compile root graph.
|
|
|
|
|
auto root_graph_id = CompileGraph(root_graph);
|
|
|
|
|
|
|
|
|
|
// Compile sub graphs.
|
|
|
|
|
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
|
|
|
|
|
for (auto sub_graph : sub_graphs) {
|
|
|
|
|
if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
|
|
|
|
(void)CompileGraph(sub_graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return root_graph_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_partition_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(backend_);
|
|
|
|
|
|
|
|
|
|
// Split graph to segments.
|
|
|
|
|
const auto &segments = graph_partition_->Partition(func_graph);
|
|
|
|
|
MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
|
|
|
|
|
|
|
|
|
|
// Foreach the segments to compile graph.
|
|
|
|
|
std::vector<uint32_t> graph_ids;
|
|
|
|
|
for (const auto &segment : segments) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(segment);
|
|
|
|
|
// Compile the normal nodes, which doesn't contain the cut node.
|
|
|
|
|
if (!segment->is_cut_) {
|
|
|
|
|
if (segment->nodes_.size() == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The segments size is 0.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope();
|
|
|
|
|
|
|
|
|
|
// Compile the anfNodes list to kernelGraph, return the graph id of kernelGraph.
|
|
|
|
|
auto graph_id = backend_->CompileGraph(segment->nodes_);
|
|
|
|
|
graph_ids.emplace_back(graph_id);
|
|
|
|
|
} else {
|
|
|
|
|
// Compile the cut node.
|
|
|
|
|
auto cut_node = segment->nodes_[0];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cut_node);
|
|
|
|
|
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return graph_ids[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
|
|
|
|
|
// Return false in the transitional stage.
|
|
|
|
|
bool IsMindRTUsed() { return false; }
|
|
|
|
|
|
|
|
|
|
BackendPtr CreateBackend() {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
@ -533,7 +600,13 @@ BackendPtr CreateBackend() {
|
|
|
|
|
if (name == kMsConvert) {
|
|
|
|
|
std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
|
|
|
|
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
|
|
|
|
auto backend = std::make_shared<MsBackend>(name, target, device_id);
|
|
|
|
|
BackendPtr backend = nullptr;
|
|
|
|
|
// Create MindRTBackend or MsBackend according to whether mindrt is used.
|
|
|
|
|
if (IsMindRTUsed()) {
|
|
|
|
|
backend = std::make_shared<MindRTBackend>(name, target, device_id);
|
|
|
|
|
} else {
|
|
|
|
|
backend = std::make_shared<MsBackend>(name, target, device_id);
|
|
|
|
|
}
|
|
|
|
|
std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
|
|
|
|
if (device_target == kAscendDevice) {
|
|
|
|
|
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
|
|
|
|