From 539d88552a1a17611592ad0bd6ae99a0c17845ca Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Wed, 3 Feb 2021 14:43:42 +0800 Subject: [PATCH] check modeltype Signed-off-by: zhoufeng --- include/api/context.h | 1 + mindspore/ccsrc/cxx_api/model/model.cc | 26 +++++++++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/include/api/context.h b/include/api/context.h index 0aea49dd99..b1b39b91b0 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -25,6 +25,7 @@ namespace mindspore { constexpr auto kDeviceTypeAscend310 = "Ascend310"; constexpr auto kDeviceTypeAscend910 = "Ascend910"; +constexpr auto kDeviceTypeGPU = "GPU"; struct MS_API Context { virtual ~Context() = default; diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 88d364f7f2..a0b2237dc2 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -20,6 +20,13 @@ #include "utils/utils.h" namespace mindspore { +namespace { +const std::map> kSupportedModelMap = { + {kDeviceTypeAscend310, {kOM, kMindIR}}, + {kDeviceTypeAscend910, {kMindIR}}, + {kDeviceTypeGPU, {kMindIR}}, +}; +} Status Model::Build() { MS_EXCEPTION_IF_NULL(impl_); return impl_->Build(); @@ -61,8 +68,21 @@ Model::Model(const std::vector &network, const std::shared_ptr Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType) { - return Factory::Instance().CheckModelSupport(device_type); -} +bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { + if (!Factory::Instance().CheckModelSupport(device_type)) { + return false; + } + + auto first_iter = kSupportedModelMap.find(device_type); + if (first_iter == kSupportedModelMap.end()) { + return false; + } + + auto secend_iter = first_iter->second.find(model_type); + if (secend_iter == first_iter->second.end()) { + return false; + } + return true; +} } // namespace mindspore