|
|
|
@ -56,55 +56,50 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
|
|
|
|
|
return graph->graph_id();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
|
|
|
|
VectorRef *outputs) {
|
|
|
|
|
GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
|
|
|
std::vector<tensor::TensorPtr> *input_tensors,
|
|
|
|
|
const std::vector<int64_t> &tensors_mask) {
|
|
|
|
|
// Check if the graph cache exists.
|
|
|
|
|
auto iter = run_op_graphs_.find(graph_info);
|
|
|
|
|
if (iter != run_op_graphs_.end()) {
|
|
|
|
|
const auto &graph = iter->second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
return graph->graph_id();
|
|
|
|
|
}
|
|
|
|
|
// Generate kernel graph.
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
|
auto graph = session_->GetGraph(graph_id);
|
|
|
|
|
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(actor_set);
|
|
|
|
|
GraphScheduler::GetInstance().Run(actor_set);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
|
|
|
std::vector<tensor::TensorPtr> *input_tensors,
|
|
|
|
|
const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
|
|
|
|
|
// Check if the graph cache exists.
|
|
|
|
|
if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
|
|
|
|
|
// Prepare the graph
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
|
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_context_);
|
|
|
|
|
device_context_->SetOperatorInfo(graph->execution_order());
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_context_);
|
|
|
|
|
device_context_->SetOperatorInfo(graph->execution_order());
|
|
|
|
|
device_context_->OptimizeSingleOpGraph(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
|
session_->RunOpHideNopNode(graph);
|
|
|
|
|
session_->RunOpRemoveNopNode(graph);
|
|
|
|
|
|
|
|
|
|
device_context_->OptimizeSingleOpGraph(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
|
session_->RunOpHideNopNode(graph);
|
|
|
|
|
// Generate 'KernelMod' for kernel in graph.
|
|
|
|
|
device_context_->CreateKernel(graph->execution_order());
|
|
|
|
|
|
|
|
|
|
device_context_->CreateKernel(graph->execution_order());
|
|
|
|
|
run_op_graphs_[graph_info] = graph;
|
|
|
|
|
}
|
|
|
|
|
// Transform graph to actor DAG, contains build and link.
|
|
|
|
|
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
|
|
|
|
|
run_op_graphs_[graph_info] = graph;
|
|
|
|
|
return graph->graph_id();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
session_->EraseValueNodeTensor(tensors_mask, input_tensors);
|
|
|
|
|
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
|
return session_->GetGraph(graph_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// wait for allreduce
|
|
|
|
|
for (auto &tensor : *input_tensors) {
|
|
|
|
|
if (tensor->NeedWaitDevice()) {
|
|
|
|
|
tensor->WaitDevice();
|
|
|
|
|
}
|
|
|
|
|
KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
|
|
|
|
|
auto iter = run_op_graphs_.find(graph_info);
|
|
|
|
|
if (iter == run_op_graphs_.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// run op
|
|
|
|
|
auto graph = run_op_graphs_[graph_info];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
session_->RunOpRemoveNopNode(graph);
|
|
|
|
|
|
|
|
|
|
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
|
|
|
|
|
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(actor_set);
|
|
|
|
|
GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
} // namespace runtime
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|