|
|
|
@ -14,6 +14,9 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "include/api/context.h"
|
|
|
|
|
#include <any>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
|
|
|
|
|
|
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
|
|
|
|
@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
|
|
|
|
|
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
template <class T>
|
|
|
|
|
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
|
|
|
|
|
auto iter = context->params.find(key);
|
|
|
|
|
if (iter == context->params.end()) {
|
|
|
|
|
return T();
|
|
|
|
|
struct Context::Data {
|
|
|
|
|
std::map<std::string, std::any> params;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Context::Context() : data(std::make_shared<Data>()) {}
|
|
|
|
|
|
|
|
|
|
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
|
|
|
|
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
|
|
|
|
|
static U empty_result;
|
|
|
|
|
if (context == nullptr || context->data == nullptr) {
|
|
|
|
|
return empty_result;
|
|
|
|
|
}
|
|
|
|
|
auto iter = context->data->params.find(key);
|
|
|
|
|
if (iter == context->data->params.end()) {
|
|
|
|
|
return empty_result;
|
|
|
|
|
}
|
|
|
|
|
const std::any &value = iter->second;
|
|
|
|
|
if (value.type() != typeid(T)) {
|
|
|
|
|
return T();
|
|
|
|
|
if (value.type() != typeid(U)) {
|
|
|
|
|
return empty_result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::any_cast<T>(value);
|
|
|
|
|
return std::any_cast<const U &>(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
|
|
|
|
@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
|
|
|
|
|
return g_context;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
|
|
|
|
|
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
|
|
|
|
|
auto global_context = GetGlobalContext();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(global_context);
|
|
|
|
|
global_context->params[kGlobalContextDeviceTarget] = device_target;
|
|
|
|
|
if (global_context->data == nullptr) {
|
|
|
|
|
global_context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(global_context->data);
|
|
|
|
|
}
|
|
|
|
|
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GlobalContext::GetGlobalDeviceTarget() {
|
|
|
|
|
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
|
|
|
|
|
auto global_context = GetGlobalContext();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(global_context);
|
|
|
|
|
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
|
|
|
|
|
auto global_context = GetGlobalContext();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(global_context);
|
|
|
|
|
global_context->params[kGlobalContextDeviceID] = device_id;
|
|
|
|
|
if (global_context->data == nullptr) {
|
|
|
|
|
global_context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(global_context->data);
|
|
|
|
|
}
|
|
|
|
|
global_context->data->params[kGlobalContextDeviceID] = device_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t GlobalContext::GetGlobalDeviceID() {
|
|
|
|
@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
|
|
|
|
|
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
|
|
|
|
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionInsertOpCfgPath] = cfg_path;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
|
|
|
|
|
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
|
|
|
|
|
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionInputFormat] = format;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionInputFormat] = CharToString(format);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
|
|
|
|
|
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
return GetValue<std::string>(context, kModelOptionInputFormat);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
|
|
|
|
|
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionInputShape] = shape;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionInputShape] = CharToString(shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
|
|
|
|
|
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
return GetValue<std::string>(context, kModelOptionInputShape);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionOutputType] = output_type;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionOutputType] = output_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
|
|
|
|
@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex
|
|
|
|
|
return GetValue<enum DataType>(context, kModelOptionOutputType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
|
|
|
|
|
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionPrecisionMode] = precision_mode;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
|
|
|
|
|
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
return GetValue<std::string>(context, kModelOptionPrecisionMode);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
|
|
|
|
const std::string &op_select_impl_mode) {
|
|
|
|
|
const std::vector<char> &op_select_impl_mode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
|
|
|
|
|
if (context->data == nullptr) {
|
|
|
|
|
context->data = std::make_shared<Data>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context->data);
|
|
|
|
|
}
|
|
|
|
|
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
|
|
|
|
|
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
|
|
|
return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
|
|
|
|
|
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
|
|
|
|
|
return StringToChar(ref);
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|