From 156e8e92e9f46386ddcef0939097e95292f57a73 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Thu, 4 Mar 2021 15:08:26 +0800 Subject: [PATCH] gpu inference config --- include/api/cell.h | 2 ++ include/api/context.h | 13 +++++++++++++ mindspore/ccsrc/cxx_api/cell.cc | 2 ++ mindspore/ccsrc/cxx_api/context.cc | 18 ++++++++++++++++++ .../ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc | 5 ++++- mindspore/ccsrc/cxx_api/graph/graph_impl.h | 4 +++- mindspore/ccsrc/cxx_api/model/ms/ms_model.cc | 2 ++ 7 files changed, 44 insertions(+), 2 deletions(-) diff --git a/include/api/cell.h b/include/api/cell.h index 3039fa816b..8b9580af49 100644 --- a/include/api/cell.h +++ b/include/api/cell.h @@ -22,6 +22,7 @@ #include "include/api/status.h" #include "include/api/types.h" #include "include/api/graph.h" +#include "include/api/context.h" namespace mindspore { class InputAndOutput; @@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell { explicit GraphCell(Graph &&); explicit GraphCell(const std::shared_ptr &); + void SetContext(const std::shared_ptr &context); const std::shared_ptr &GetGraph() const { return graph_; } Status Run(const std::vector &inputs, std::vector *outputs) override; std::vector GetInputs(); diff --git a/include/api/context.h b/include/api/context.h index 1d8852bdee..9105f73915 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -81,6 +81,9 @@ struct MS_API ModelContext : public Context { static inline void SetFusionSwitchConfigPath(const std::shared_ptr &context, const std::string &cfg_path); static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr &context); + static inline void SetGpuTrtInferMode(const std::shared_ptr &context, const std::string &gpu_trt_infer_mode); + static inline std::string GetGpuTrtInferMode(const std::shared_ptr &context); + private: // api without std::string static void SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path); @@ -101,6 +104,9 @@ struct MS_API ModelContext : public Context { static void SetFusionSwitchConfigPath(const std::shared_ptr &context, const std::vector &cfg_path); static std::vector GetFusionSwitchConfigPathChar(const std::shared_ptr &context); + + static void SetGpuTrtInferMode(const std::shared_ptr &context, const std::vector &gpu_trt_infer_mode); + static std::vector GetGpuTrtInferModeChar(const std::shared_ptr &context); }; void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { @@ -155,5 +161,12 @@ void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr &con std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr &context) { return CharToString(GetFusionSwitchConfigPathChar(context)); } + +void ModelContext::SetGpuTrtInferMode(const std::shared_ptr &context, const std::string &gpu_trt_infer_mode) { + SetGpuTrtInferMode(context, StringToChar(gpu_trt_infer_mode)); +} +std::string ModelContext::GetGpuTrtInferMode(const std::shared_ptr &context) { + return CharToString(GetGpuTrtInferModeChar(context)); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/mindspore/ccsrc/cxx_api/cell.cc b/mindspore/ccsrc/cxx_api/cell.cc index ebf3a4706e..d7d5b53b30 100644 --- a/mindspore/ccsrc/cxx_api/cell.cc +++ b/mindspore/ccsrc/cxx_api/cell.cc @@ -78,6 +78,8 @@ GraphCell::GraphCell(Graph &&graph) executor_->SetGraph(graph_); } +void GraphCell::SetContext(const std::shared_ptr &context) { return executor_->SetContext(context); } + Status GraphCell::Run(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(executor_); return executor_->Run(inputs, outputs); diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index b8ef34b955..3f5597098d 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -31,6 +31,8 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode"; // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16" constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path"; +// "False": Inference with native backend, "True": Inference with Tensor-RT engine, default as "False" +constexpr auto kModelOptionGpuTrtInferMode = "mindspore.option.gpu_trt_infer_mode"; namespace mindspore { struct Context::Data { @@ -217,4 +219,20 @@ std::vector ModelContext::GetFusionSwitchConfigPathChar(const std::shared_ const std::string &ref = GetValue(context, KModelOptionFusionSwitchCfgPath); return StringToChar(ref); } + +void ModelContext::SetGpuTrtInferMode(const std::shared_ptr &context, + const std::vector &gpu_trt_infer_mode) { + MS_EXCEPTION_IF_NULL(context); + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionGpuTrtInferMode] = CharToString(gpu_trt_infer_mode); +} + +std::vector ModelContext::GetGpuTrtInferModeChar(const std::shared_ptr &context) { + MS_EXCEPTION_IF_NULL(context); + const std::string &ref = GetValue(context, kModelOptionGpuTrtInferMode); + return StringToChar(ref); +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc index 749a4bc08b..d7872894f8 100644 --- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc @@ -51,7 +51,10 @@ Status GPUGraphImpl::InitEnv() { ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); ms_context->set_param(MS_CTX_DEVICE_TARGET, kGPUDevice); - ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, true); + auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_); + if (enable_trt == "True") { + ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, true); + } session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); if (session_impl_ == nullptr) { diff --git a/mindspore/ccsrc/cxx_api/graph/graph_impl.h b/mindspore/ccsrc/cxx_api/graph/graph_impl.h index 42c843225d..3b46e408da 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/graph_impl.h @@ -29,11 +29,12 @@ namespace mindspore { class GraphCell::GraphImpl { public: - GraphImpl() = default; + GraphImpl() : graph_(nullptr), graph_context_(nullptr) {} virtual ~GraphImpl() = default; std::shared_ptr &MutableGraphData() const { return graph_->graph_data_; } void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } + void SetContext(const std::shared_ptr &context) { graph_context_ = context; } virtual Status Run(const std::vector &inputs, std::vector *outputs) = 0; virtual Status Load() = 0; @@ -43,6 +44,7 @@ class GraphCell::GraphImpl { protected: std::shared_ptr graph_; + std::shared_ptr graph_context_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index 5a4366d0b7..9a07105a22 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -70,6 +70,7 @@ std::shared_ptr MsModel::GenerateGraphCell(const std::vector(graph); MS_EXCEPTION_IF_NULL(graph_cell); + graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed."; @@ -95,6 +96,7 @@ Status MsModel::Build() { MS_EXCEPTION_IF_NULL(graph); auto graph_cell = std::make_shared(graph); MS_EXCEPTION_IF_NULL(graph_cell); + graph_cell->SetContext(model_context_); auto ret = ModelImpl::Load(graph_cell); if (ret != kSuccess) { MS_LOG(ERROR) << "Load failed.";