diff --git a/include/api/cell.h b/include/api/cell.h index d6b815fd74..368cf25d34 100644 --- a/include/api/cell.h +++ b/include/api/cell.h @@ -25,6 +25,7 @@ namespace mindspore { class InputAndOutput; +class Context; using Input = InputAndOutput; using Output = InputAndOutput; @@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell { explicit GraphCell(Graph &&); explicit GraphCell(const std::shared_ptr &); + void SetContext(const std::shared_ptr &context); const std::shared_ptr &GetGraph() const { return graph_; } Status Run(const std::vector &inputs, std::vector *outputs) override; std::vector GetInputs(); diff --git a/mindspore/ccsrc/backend/session/gpu_inference_session.cc b/mindspore/ccsrc/backend/session/gpu_inference_session.cc index c1091d7c81..ccf682063f 100644 --- a/mindspore/ccsrc/backend/session/gpu_inference_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_inference_session.cc @@ -212,6 +212,5 @@ std::string GpuInferenceSession::InputsInfo(const std::vector &par } return graph + " " + actual; } - } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/cell.cc b/mindspore/ccsrc/cxx_api/cell.cc index 87ee6b7978..f031170f69 100644 --- a/mindspore/ccsrc/cxx_api/cell.cc +++ b/mindspore/ccsrc/cxx_api/cell.cc @@ -73,6 +73,18 @@ GraphCell::GraphCell(const std::shared_ptr &graph) : graph_(graph) { MS_E GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared(graph)) { MS_EXCEPTION_IF_NULL(graph_); } +void GraphCell::SetContext(const std::shared_ptr &context) { + if (executor_ == nullptr) { + executor_ = Factory::Instance().Create(g_device_target); + if (executor_ == nullptr) { + MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; + return; + } + executor_->SetGraph(graph_); + } + executor_->SetContext(context); +} + Status GraphCell::Run(const std::vector &inputs, std::vector *outputs) { if (executor_ == nullptr) { executor_ = Factory::Instance().Create(g_device_target); diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc index 263cbd3ee9..2891c3cdbb 100644 --- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc @@ -54,7 +54,17 @@ Status GPUGraphImpl::InitEnv() { ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); ms_context->set_param(MS_CTX_DEVICE_TARGET, kGPUDevice); - ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, false); + + auto &device_infos = graph_context_->MutableDeviceInfo(); + if (device_infos.size() != 1) { + return kMCDeviceError; + } + auto gpu_info = device_infos[0]->Cast(); + if (gpu_info == nullptr) { + return kMCDeviceError; + } + auto enable_trt = gpu_info->GetGpuTrtInferMode(); + ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, enable_trt); session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); if (session_impl_ == nullptr) { diff --git a/mindspore/ccsrc/cxx_api/graph/graph_impl.h b/mindspore/ccsrc/cxx_api/graph/graph_impl.h index 23fea0bea9..2b678e6340 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/graph_impl.h @@ -29,11 +29,12 @@ namespace mindspore { class GraphCell::GraphImpl { public: - GraphImpl() : graph_(nullptr) {} + GraphImpl() : graph_(nullptr), graph_context_(nullptr) {} virtual ~GraphImpl() = default; std::shared_ptr &MutableGraphData() const { return graph_->graph_data_; } void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } + void SetContext(const std::shared_ptr &context) { graph_context_ = context; } virtual Status Run(const std::vector &inputs, std::vector *outputs) = 0; virtual Status Load(uint32_t device_id) = 0; @@ -43,6 +44,7 @@ class GraphCell::GraphImpl { protected: std::shared_ptr graph_; + std::shared_ptr graph_context_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index da2aec4634..ed34b3979d 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -74,6 +74,7 @@ std::shared_ptr MsModel::GenerateGraphCell(const std::vector(graph); MS_EXCEPTION_IF_NULL(graph_cell); + graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed."; @@ -99,6 +100,7 @@ Status MsModel::Build() { MS_EXCEPTION_IF_NULL(graph); auto graph_cell = std::make_shared(graph); MS_EXCEPTION_IF_NULL(graph_cell); + graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed."; diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 6ea55c9ba6..b39d75b522 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { set_param(MS_CTX_ENABLE_GRAPH_KERNEL, false); set_param(MS_CTX_ENABLE_SPARSE, false); set_param(MS_CTX_ENABLE_PARALLEL_SPLIT, false); + set_param(MS_CTX_ENABLE_INFER_OPT, false); set_param(MS_CTX_GRAD_FOR_SCALAR, false); backend_policy_ = policy_map_[policy];