diff --git a/include/api/context.h b/include/api/context.h index 0aea49dd99..b1b39b91b0 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -25,6 +25,7 @@ namespace mindspore { constexpr auto kDeviceTypeAscend310 = "Ascend310"; constexpr auto kDeviceTypeAscend910 = "Ascend910"; +constexpr auto kDeviceTypeGPU = "GPU"; struct MS_API Context { virtual ~Context() = default; diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 88d364f7f2..a0b2237dc2 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -20,6 +20,13 @@ #include "utils/utils.h" namespace mindspore { +namespace { +const std::map> kSupportedModelMap = { + {kDeviceTypeAscend310, {kOM, kMindIR}}, + {kDeviceTypeAscend910, {kMindIR}}, + {kDeviceTypeGPU, {kMindIR}}, +}; +} Status Model::Build() { MS_EXCEPTION_IF_NULL(impl_); return impl_->Build(); @@ -61,8 +68,21 @@ Model::Model(const std::vector &network, const std::shared_ptr Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType) { - return Factory::Instance().CheckModelSupport(device_type); -} +bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { + if (!Factory::Instance().CheckModelSupport(device_type)) { + return false; + } + + auto first_iter = kSupportedModelMap.find(device_type); + if (first_iter == kSupportedModelMap.end()) { + return false; + } + + auto secend_iter = first_iter->second.find(model_type); + if (secend_iter == first_iter->second.end()) { + return false; + } + return true; +} } // namespace mindspore