|
|
|
@ -101,6 +101,7 @@
|
|
|
|
|
#include "graph/common/local_context.h"
|
|
|
|
|
#include "graph/common/omg_util.h"
|
|
|
|
|
#include "common/formats/utils/formats_trans_utils.h"
|
|
|
|
|
#include "register/custom_pass_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
const char *const kSummary = "Summary";
|
|
|
|
@ -765,10 +766,24 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GraphManager::RunCustomPass(const GraphNodePtr &graph_node) {
|
|
|
|
|
ConstGraphPtr const_graph = graph_node->GetGraph();
|
|
|
|
|
auto comp_graph = GraphUtils::GetComputeGraph(*const_graph);
|
|
|
|
|
GE_DUMP(comp_graph, "RunCustomPassBegin");
|
|
|
|
|
|
|
|
|
|
GE_TIMESTAMP_START(RunCustomPass);
|
|
|
|
|
GraphPtr graph = std::const_pointer_cast<Graph>(const_graph);
|
|
|
|
|
GE_CHK_STATUS_RET(CustomPassHelper::Instance().Run(graph), "Graph[%s] run custom pass fail.",
|
|
|
|
|
comp_graph->GetName().c_str());
|
|
|
|
|
GE_TIMESTAMP_END(RunCustomPass, "GraphBuilder::RunCustomPass");
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
|
|
|
|
|
GeRootModelPtr &ge_root_model, uint64_t session_id) {
|
|
|
|
|
GE_CHECK_NOTNULL(graph_node);
|
|
|
|
|
GE_CHECK_NOTNULL(graph_node->GetGraph());
|
|
|
|
|
GE_CHK_STATUS_RET_NOLOG(RunCustomPass(graph_node));
|
|
|
|
|
auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph());
|
|
|
|
|
GE_CHECK_NOTNULL(compute_graph);
|
|
|
|
|
compute_graph->SetSessionID(session_id);
|
|
|
|
|