GPU inference config

commit 156e8e92e9 (pull/12802/head)
Author: wilfChen (4 years ago)
Parent: fe7652e87f

@@ -22,6 +22,7 @@
 #include "include/api/status.h"
 #include "include/api/types.h"
 #include "include/api/graph.h"
+#include "include/api/context.h"
 
 namespace mindspore {
 class InputAndOutput;
@@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell<GraphCell> {
   explicit GraphCell(Graph &&);
   explicit GraphCell(const std::shared_ptr<Graph> &);
 
+  void SetContext(const std::shared_ptr<Context> &context);
   const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
   Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
   std::vector<MSTensor> GetInputs();

@@ -81,6 +81,9 @@ struct MS_API ModelContext : public Context {
   static inline void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
   static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context);
+
+  static inline void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode);
+  static inline std::string GetGpuTrtInferMode(const std::shared_ptr<Context> &context);
 
  private:
   // api without std::string
   static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
@@ -101,6 +104,9 @@ struct MS_API ModelContext : public Context {
   static void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
   static std::vector<char> GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context);
+
+  static void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::vector<char> &gpu_trt_infer_mode);
+  static std::vector<char> GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context);
 };
 
 void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
@@ -155,5 +161,12 @@ void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &con
 std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context) {
   return CharToString(GetFusionSwitchConfigPathChar(context));
 }
+void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode) {
+  SetGpuTrtInferMode(context, StringToChar(gpu_trt_infer_mode));
+}
+
+std::string ModelContext::GetGpuTrtInferMode(const std::shared_ptr<Context> &context) {
+  return CharToString(GetGpuTrtInferModeChar(context));
+}
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
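For orientation, here is a minimal caller-side sketch of the new option, assembled only from the API visible in this diff. Assumptions are flagged inline: that Context is default-constructible as used here, that "GPU" is the device-target string, and that `graph` has been loaded elsewhere (e.g. from a MindIR file).

    #include <memory>
    #include <string>
    #include <vector>
    #include "include/api/cell.h"
    #include "include/api/context.h"

    // Sketch: run a pre-loaded graph on GPU with the TensorRT engine enabled.
    void RunWithTrt(const std::shared_ptr<mindspore::Graph> &graph,
                    const std::vector<mindspore::MSTensor> &inputs) {
      mindspore::GlobalContext::SetGlobalDeviceTarget("GPU");  // assumed target name
      auto context = std::make_shared<mindspore::Context>();
      // Only the exact string "True" enables TensorRT; any other value keeps
      // the native backend (see the GPUGraphImpl::InitEnv hunk below).
      mindspore::ModelContext::SetGpuTrtInferMode(context, "True");

      mindspore::GraphCell cell(graph);
      cell.SetContext(context);  // must precede graph load/run
      std::vector<mindspore::MSTensor> outputs;
      (void)cell.Run(inputs, &outputs);
    }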

@@ -78,6 +78,8 @@ GraphCell::GraphCell(Graph &&graph)
   executor_->SetGraph(graph_);
 }
 
+void GraphCell::SetContext(const std::shared_ptr<Context> &context) { return executor_->SetContext(context); }
+
 Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   return executor_->Run(inputs, outputs);

@@ -31,6 +31,8 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
 // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", or "allow_mix_precision"; defaults to "force_fp16"
 constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
 constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path";
+// "False": inference with the native GPU backend; "True": inference with the TensorRT engine; defaults to "False"
+constexpr auto kModelOptionGpuTrtInferMode = "mindspore.option.gpu_trt_infer_mode";
 
 namespace mindspore {
 struct Context::Data {
@@ -217,4 +219,20 @@ std::vector<char> ModelContext::GetFusionSwitchConfigPathChar(const std::shared_
   const std::string &ref = GetValue<std::string>(context, KModelOptionFusionSwitchCfgPath);
   return StringToChar(ref);
 }
+
+void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context,
+                                      const std::vector<char> &gpu_trt_infer_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  if (context->data == nullptr) {
+    context->data = std::make_shared<Data>();
+    MS_EXCEPTION_IF_NULL(context->data);
+  }
+  context->data->params[kModelOptionGpuTrtInferMode] = CharToString(gpu_trt_infer_mode);
+}
+
+std::vector<char> ModelContext::GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  const std::string &ref = GetValue<std::string>(context, kModelOptionGpuTrtInferMode);
+  return StringToChar(ref);
+}
 }  // namespace mindspore
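A side note on the pattern above: the public std::string overloads forward to std::vector<char> variants so that no std::string object crosses the library/ABI boundary. The conversion helpers presumably reduce to the sketch below (in MindSpore they live in include/api/dual_abi_helper.h; treat the exact signatures as an assumption):

    #include <string>
    #include <vector>

    // Presumed shape of the dual-ABI conversion helpers.
    inline std::vector<char> StringToChar(const std::string &s) {
      return std::vector<char>(s.begin(), s.end());
    }
    inline std::string CharToString(const std::vector<char> &c) {
      return std::string(c.begin(), c.end());
    }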

@@ -51,7 +51,10 @@ Status GPUGraphImpl::InitEnv() {
   ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
   ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
   ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
-  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+  auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_);
+  if (enable_trt == "True") {
+    ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+  }
 
   session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
   if (session_impl_ == nullptr) {
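Worth noting: the gate above is an exact, case-sensitive match, so only "True" enables the TensorRT path; "true" or "TRUE" silently falls back to the native backend. A hypothetical tolerant check, not part of this commit, could normalize first:

    #include <algorithm>
    #include <cctype>
    #include <string>

    // Hypothetical helper (not in this commit): case-insensitive mode check.
    bool TrtRequested(std::string mode) {
      std::transform(mode.begin(), mode.end(), mode.begin(),
                     [](unsigned char ch) { return static_cast<char>(std::tolower(ch)); });
      return mode == "true";
    }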

@@ -29,11 +29,12 @@
 namespace mindspore {
 class GraphCell::GraphImpl {
  public:
-  GraphImpl() = default;
+  GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
   virtual ~GraphImpl() = default;
 
   std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
+  void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }
 
   virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
   virtual Status Load() = 0;
@@ -43,6 +44,7 @@ class GraphCell::GraphImpl {
 
  protected:
   std::shared_ptr<Graph> graph_;
+  std::shared_ptr<Context> graph_context_;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H

@@ -70,6 +70,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
   MS_EXCEPTION_IF_NULL(graph);
   auto graph_cell = std::make_shared<GraphCell>(graph);
   MS_EXCEPTION_IF_NULL(graph_cell);
+  graph_cell->SetContext(model_context_);
   auto ret = ModelImpl::Load(graph_cell);
   if (ret != kSuccess) {
     MS_LOG(ERROR) << "Load failed.";
@@ -95,6 +96,7 @@ Status MsModel::Build() {
   MS_EXCEPTION_IF_NULL(graph);
   auto graph_cell = std::make_shared<GraphCell>(graph);
   MS_EXCEPTION_IF_NULL(graph_cell);
+  graph_cell->SetContext(model_context_);
   auto ret = ModelImpl::Load(graph_cell);
   if (ret != kSuccess) {
     MS_LOG(ERROR) << "Load failed.";
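Note on ordering: in both call sites, graph_cell->SetContext(model_context_) runs before ModelImpl::Load(graph_cell). That ordering matters, since GPUGraphImpl::InitEnv (earlier hunk) reads the TRT mode from graph_context_, presumably during load; setting the context afterwards would leave the option unread.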
