From 504f2158001e9a13662929536624f6835e168b9f Mon Sep 17 00:00:00 2001
From: zhoufeng
Date: Fri, 19 Feb 2021 17:09:40 +0800
Subject: [PATCH] api support dual abi

Signed-off-by: zhoufeng
---
 include/api/context.h                             | 100 ++++++++++++---
 include/api/dual_abi_helper.h                     |  26 ++++
 include/api/graph.h                               |   1 -
 include/api/model.h                               |   9 +-
 include/api/serialization.h                       |  10 +-
 include/api/status.h                              |  69 ++++++----
 include/api/types.h                               |  40 +++++-
 mindspore/ccsrc/cxx_api/context.cc                | 119 ++++++++++++----
 mindspore/ccsrc/cxx_api/model/model.cc            |   7 +-
 mindspore/ccsrc/cxx_api/serialization.cc          |   9 +-
 mindspore/ccsrc/cxx_api/types.cc                  |  20 +--
 .../ccsrc/minddata/dataset/CMakeLists.txt         |   3 +-
 .../ccsrc/minddata/dataset/include/tensor.h       |   3 +
 mindspore/core/utils/status.cc                    | 120 +++++++++++++++---
 mindspore/lite/src/cxx_api/model/model.cc         |   2 +-
 mindspore/lite/src/cxx_api/serialization.cc       |   2 +-
 mindspore/lite/src/cxx_api/types.cc               |  19 ++-
 tests/ut/cpp/cxx_api/context_test.cc              |   6 -
 18 files changed, 425 insertions(+), 140 deletions(-)
 create mode 100644 include/api/dual_abi_helper.h

diff --git a/include/api/context.h b/include/api/context.h
index b1b39b91b0..90dfa408d6 100644
--- a/include/api/context.h
+++ b/include/api/context.h
@@ -16,11 +16,11 @@
 #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
 #define MINDSPORE_INCLUDE_API_CONTEXT_H
-#include <map>
-#include <any>
 #include <string>
 #include <memory>
+#include <vector>
 #include "include/api/types.h"
+#include "include/api/dual_abi_helper.h"
 
 namespace mindspore {
 constexpr auto kDeviceTypeAscend310 = "Ascend310";
@@ -28,38 +28,108 @@ constexpr auto kDeviceTypeAscend910 = "Ascend910";
 constexpr auto kDeviceTypeGPU = "GPU";
 
 struct MS_API Context {
+ public:
+  Context();
   virtual ~Context() = default;
-  std::map<std::string, std::any> params;
+  struct Data;
+  std::shared_ptr<Data> data;
 };
 
 struct MS_API GlobalContext : public Context {
+ public:
   static std::shared_ptr<Context> GetGlobalContext();
 
-  static void SetGlobalDeviceTarget(const std::string &device_target);
-  static std::string GetGlobalDeviceTarget();
+  static inline void SetGlobalDeviceTarget(const std::string &device_target);
+  static inline std::string GetGlobalDeviceTarget();
 
   static void SetGlobalDeviceID(const uint32_t &device_id);
   static uint32_t GetGlobalDeviceID();
+
+ private:
+  // api without std::string
+  static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
+  static std::vector<char> GetGlobalDeviceTargetChar();
 };
 
 struct MS_API ModelContext : public Context {
-  static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
-  static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
+ public:
+  static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
+  static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
 
-  static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
-  static std::string GetInputFormat(const std::shared_ptr<Context> &context);
+  static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
+  static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
 
-  static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
-  static std::string GetInputShape(const std::shared_ptr<Context> &context);
+  static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
+  static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
 
   static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
   static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
 
-  static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
-  static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
+  static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
+  static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
+
+  static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                         const std::string &op_select_impl_mode);
+  static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
+
+ private:
+  // api without std::string
+  static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
+  static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
+
+  static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
+  static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
+
+  static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
+  static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
+
+  static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
+  static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
-  static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
-  static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
+  static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                  const std::vector<char> &op_select_impl_mode);
+  static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
 };
+
+void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
+  SetGlobalDeviceTarget(StringToChar(device_target));
+}
+std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
+
+void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
+  SetInsertOpConfigPath(context, StringToChar(cfg_path));
+}
+std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
+  return CharToString(GetInsertOpConfigPathChar(context));
+}
+
+void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
+  SetInputFormat(context, StringToChar(format));
+}
+std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
+  return CharToString(GetInputFormatChar(context));
+}
+
+void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
+  SetInputShape(context, StringToChar(shape));
+}
+std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
+  return CharToString(GetInputShapeChar(context));
+}
+
+void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
+  SetPrecisionMode(context, StringToChar(precision_mode));
+}
+std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
+  return CharToString(GetPrecisionModeChar(context));
+}
+
+void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                       const std::string &op_select_impl_mode) {
+  SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
+}
+std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
+  return CharToString(GetOpSelectImplModeChar(context));
+}
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
diff --git a/include/api/dual_abi_helper.h
b/include/api/dual_abi_helper.h new file mode 100644 index 0000000000..6bf9c6eec8 --- /dev/null +++ b/include/api/dual_abi_helper.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ +#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ + +#include +#include + +namespace mindspore { +inline std::vector StringToChar(const std::string &s) { return std::vector(s.begin(), s.end()); } +inline std::string CharToString(const std::vector &c) { return std::string(c.begin(), c.end()); } +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ diff --git a/include/api/graph.h b/include/api/graph.h index a9288eb5a1..892f60495a 100644 --- a/include/api/graph.h +++ b/include/api/graph.h @@ -17,7 +17,6 @@ #define MINDSPORE_INCLUDE_API_GRAPH_H #include -#include #include #include #include diff --git a/include/api/model.h b/include/api/model.h index 8d401085eb..78f202fae2 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -25,6 +25,7 @@ #include "include/api/types.h" #include "include/api/graph.h" #include "include/api/cell.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { class ModelImpl; @@ -46,10 +47,16 @@ class MS_API Model { std::vector GetInputs(); std::vector GetOutputs(); - static bool CheckModelSupport(const std::string &device_type, ModelType model_type); + static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type); private: + // api without std::string + static bool CheckModelSupport(const std::vector &device_type, ModelType model_type); std::shared_ptr impl_; }; + +bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { + return CheckModelSupport(StringToChar(device_type), model_type); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/include/api/serialization.h b/include/api/serialization.h index 2c34b826d3..c5fb61eb07 100644 --- a/include/api/serialization.h +++ b/include/api/serialization.h @@ -24,16 +24,24 @@ #include "include/api/types.h" #include "include/api/model.h" #include "include/api/graph.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { class MS_API Serialization { public: static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type); - static Graph LoadModel(const std::string &file, ModelType model_type); + inline static Graph LoadModel(const std::string &file, ModelType model_type); static Status LoadCheckPoint(const std::string &ckpt_file, std::map *parameters); static Status SetParameters(const std::map ¶meters, Model *model); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); + + private: + static Graph LoadModel(const std::vector &file, ModelType model_type); }; + +Graph Serialization::LoadModel(const std::string &file, ModelType 
model_type) { + return LoadModel(StringToChar(file), model_type); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H diff --git a/include/api/status.h b/include/api/status.h index 131a15372c..3a5c7c4e01 100644 --- a/include/api/status.h +++ b/include/api/status.h @@ -16,9 +16,13 @@ #ifndef MINDSPORE_INCLUDE_API_STATUS_H #define MINDSPORE_INCLUDE_API_STATUS_H +#include #include +#include #include #include +#include "include/api/dual_abi_helper.h" +#include "include/api/types.h" namespace mindspore { enum CompCode : uint32_t { @@ -100,46 +104,61 @@ enum StatusCode : uint32_t { kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */ }; -class Status { +class MS_API Status { public: - Status() : status_code_(kSuccess), line_of_code_(-1) {} - Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) - : status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {} - Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); + Status(); + inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit) + inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); ~Status() = default; - enum StatusCode StatusCode() const { return status_code_; } - const std::string &ToString() const { return status_msg_; } + enum StatusCode StatusCode() const; + inline std::string ToString() const; - int GetLineOfCode() const { return line_of_code_; } - const std::string &GetErrDescription() const { return status_msg_; } - const std::string &SetErrDescription(const std::string &err_description); + int GetLineOfCode() const; + inline std::string GetErrDescription() const; + inline std::string SetErrDescription(const std::string &err_description); friend std::ostream &operator<<(std::ostream &os, const Status &s); - bool operator==(const Status &other) const { return status_code_ == other.status_code_; } - bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } - bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } - bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } + bool operator==(const Status &other) const; + bool operator==(enum StatusCode other_code) const; + bool operator!=(const Status &other) const; + bool operator!=(enum StatusCode other_code) const; - explicit operator bool() const { return (status_code_ == kSuccess); } - explicit operator int() const { return static_cast(status_code_); } + explicit operator bool() const; + explicit operator int() const; - static Status OK() { return Status(StatusCode::kSuccess); } + static Status OK(); - bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); } + bool IsOk() const; - bool IsError() const { return !IsOk(); } + bool IsError() const; - static std::string CodeAsString(enum StatusCode c); + static inline std::string CodeAsString(enum StatusCode c); private: - enum StatusCode status_code_; - std::string status_msg_; - int line_of_code_; - std::string file_name_; - std::string err_description_; + // api without std::string + explicit Status(enum StatusCode status_code, const std::vector &status_msg); + Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra); + std::vector ToCString() const; + std::vector GetErrDescriptionChar() const; + std::vector 
SetErrDescription(const std::vector &err_description); + static std::vector CodeAsCString(enum StatusCode c); + + struct Data; + std::shared_ptr data_; }; + +Status::Status(enum StatusCode status_code, const std::string &status_msg) + : Status(status_code, StringToChar(status_msg)) {} +Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) + : Status(code, line_of_code, file_name, StringToChar(extra)) {} +std::string Status::ToString() const { return CharToString(ToCString()); } +std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); } +std::string Status::SetErrDescription(const std::string &err_description) { + return CharToString(SetErrDescription(StringToChar(err_description))); +} +std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_STATUS_H diff --git a/include/api/types.h b/include/api/types.h index 0f4503e122..482159d8c0 100644 --- a/include/api/types.h +++ b/include/api/types.h @@ -21,6 +21,7 @@ #include #include #include "include/api/data_type.h" +#include "include/api/dual_abi_helper.h" #ifdef _WIN32 #define MS_API __declspec(dllexport) @@ -42,18 +43,18 @@ class MS_API MSTensor { public: class Impl; - static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector &shape, - const void *data, size_t data_len) noexcept; - static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector &shape, - const void *data, size_t data_len) noexcept; + static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; MSTensor(); explicit MSTensor(const std::shared_ptr &impl); - MSTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, - size_t data_len); + inline MSTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, + size_t data_len); ~MSTensor(); - const std::string &Name() const; + inline std::string Name() const; enum DataType DataType() const; const std::vector &Shape() const; int64_t ElementNum() const; @@ -68,6 +69,15 @@ class MS_API MSTensor { bool operator==(std::nullptr_t) const; private: + // api without std::string + static MSTensor CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + static MSTensor CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, + size_t data_len); + std::vector CharName() const; + friend class ModelImpl; explicit MSTensor(std::nullptr_t); std::shared_ptr impl_; @@ -92,5 +102,21 @@ class MS_API Buffer { class Impl; std::shared_ptr impl_; }; + +MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept { + return CreateTensor(StringToChar(name), type, shape, data, data_len); +} + +MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept { + return 
CreateRefTensor(StringToChar(name), type, shape, data, data_len); +} + +MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, + size_t data_len) + : MSTensor(StringToChar(name), type, shape, data, data_len) {} + +std::string MSTensor::Name() const { return CharToString(CharName()); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_TYPES_H diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index a9ea4055a0..d9679639f9 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -14,6 +14,9 @@ * limitations under the License. */ #include "include/api/context.h" +#include +#include +#include #include "utils/log_adapter.h" constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target"; @@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode"; constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; namespace mindspore { -template -static T GetValue(const std::shared_ptr &context, const std::string &key) { - auto iter = context->params.find(key); - if (iter == context->params.end()) { - return T(); +struct Context::Data { + std::map params; +}; + +Context::Context() : data(std::make_shared()) {} + +template >> +static const U &GetValue(const std::shared_ptr &context, const std::string &key) { + static U empty_result; + if (context == nullptr || context->data == nullptr) { + return empty_result; + } + auto iter = context->data->params.find(key); + if (iter == context->data->params.end()) { + return empty_result; } const std::any &value = iter->second; - if (value.type() != typeid(T)) { - return T(); + if (value.type() != typeid(U)) { + return empty_result; } - return std::any_cast(value); + return std::any_cast(value); } std::shared_ptr GlobalContext::GetGlobalContext() { @@ -47,22 +60,31 @@ std::shared_ptr GlobalContext::GetGlobalContext() { return g_context; } -void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { +void GlobalContext::SetGlobalDeviceTarget(const std::vector &device_target) { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - global_context->params[kGlobalContextDeviceTarget] = device_target; + if (global_context->data == nullptr) { + global_context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(global_context->data); + } + global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target); } -std::string GlobalContext::GetGlobalDeviceTarget() { +std::vector GlobalContext::GetGlobalDeviceTargetChar() { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - return GetValue(global_context, kGlobalContextDeviceTarget); + const std::string &ref = GetValue(global_context, kGlobalContextDeviceTarget); + return StringToChar(ref); } void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - global_context->params[kGlobalContextDeviceID] = device_id; + if (global_context->data == nullptr) { + global_context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(global_context->data); + } + global_context->data->params[kGlobalContextDeviceID] = device_id; } uint32_t GlobalContext::GetGlobalDeviceID() { @@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() { return GetValue(global_context, kGlobalContextDeviceID); } -void ModelContext::SetInsertOpConfigPath(const 
std::shared_ptr &context, const std::string &cfg_path) { +void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInsertOpCfgPath] = cfg_path; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path); } -std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr &context) { +std::vector ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInsertOpCfgPath); + const std::string &ref = GetValue(context, kModelOptionInsertOpCfgPath); + return StringToChar(ref); } -void ModelContext::SetInputFormat(const std::shared_ptr &context, const std::string &format) { +void ModelContext::SetInputFormat(const std::shared_ptr &context, const std::vector &format) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInputFormat] = format; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInputFormat] = CharToString(format); } -std::string ModelContext::GetInputFormat(const std::shared_ptr &context) { +std::vector ModelContext::GetInputFormatChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInputFormat); + const std::string &ref = GetValue(context, kModelOptionInputFormat); + return StringToChar(ref); } -void ModelContext::SetInputShape(const std::shared_ptr &context, const std::string &shape) { +void ModelContext::SetInputShape(const std::shared_ptr &context, const std::vector &shape) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInputShape] = shape; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInputShape] = CharToString(shape); } -std::string ModelContext::GetInputShape(const std::shared_ptr &context) { +std::vector ModelContext::GetInputShapeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInputShape); + const std::string &ref = GetValue(context, kModelOptionInputShape); + return StringToChar(ref); } void ModelContext::SetOutputType(const std::shared_ptr &context, enum DataType output_type) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionOutputType] = output_type; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionOutputType] = output_type; } enum DataType ModelContext::GetOutputType(const std::shared_ptr &context) { @@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr &contex return GetValue(context, kModelOptionOutputType); } -void ModelContext::SetPrecisionMode(const std::shared_ptr &context, const std::string &precision_mode) { +void ModelContext::SetPrecisionMode(const std::shared_ptr &context, const std::vector &precision_mode) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionPrecisionMode] = precision_mode; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode); } -std::string ModelContext::GetPrecisionMode(const 
std::shared_ptr &context) { +std::vector ModelContext::GetPrecisionModeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionPrecisionMode); + const std::string &ref = GetValue(context, kModelOptionPrecisionMode); + return StringToChar(ref); } void ModelContext::SetOpSelectImplMode(const std::shared_ptr &context, - const std::string &op_select_impl_mode) { + const std::vector &op_select_impl_mode) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode); } -std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr &context) { +std::vector ModelContext::GetOpSelectImplModeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionOpSelectImplMode); + const std::string &ref = GetValue(context, kModelOptionOpSelectImplMode); + return StringToChar(ref); } } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index a0b2237dc2..f50fd0c3dc 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -68,12 +68,13 @@ Model::Model(const std::vector &network, const std::shared_ptr Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { - if (!Factory::Instance().CheckModelSupport(device_type)) { +bool Model::CheckModelSupport(const std::vector &device_type, ModelType model_type) { + std::string device_type_str = CharToString(device_type); + if (!Factory::Instance().CheckModelSupport(device_type_str)) { return false; } - auto first_iter = kSupportedModelMap.find(device_type); + auto first_iter = kSupportedModelMap.find(device_type_str); if (first_iter == kSupportedModelMap.end()) { return false; } diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index ad41849d2e..331f6885e7 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -84,17 +84,18 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; } -Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { +Graph Serialization::LoadModel(const std::vector &file, ModelType model_type) { + std::string file_path = CharToString(file); if (model_type == kMindIR) { - FuncGraphPtr anf_graph = LoadMindIR(file); + FuncGraphPtr anf_graph = LoadMindIR(file_path); if (anf_graph == nullptr) { MS_LOG(EXCEPTION) << "Load model failed."; } return Graph(std::make_shared(anf_graph, kMindIR)); } else if (model_type == kOM) { - Buffer data = ReadFile(file); + Buffer data = ReadFile(file_path); if (data.Data() == nullptr) { - MS_LOG(EXCEPTION) << "Read file " << file << " failed."; + MS_LOG(EXCEPTION) << "Read file " << file_path << " failed."; } return Graph(std::make_shared(data, kOM)); } diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc index 38ecf4dee1..a3872c09b1 100644 --- a/mindspore/ccsrc/cxx_api/types.cc +++ b/mindspore/ccsrc/cxx_api/types.cc @@ -134,10 +134,11 @@ class TensorReferenceImpl : public MSTensor::Impl { std::vector shape_; }; -MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, 
const std::vector &shape, +MSTensor MSTensor::CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { + std::string name_str = CharToString(name); try { - std::shared_ptr impl = std::make_shared(name, type, shape, data, data_len); + std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len); return MSTensor(impl); } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; @@ -148,10 +149,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con } } -MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, +MSTensor MSTensor::CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { + std::string name_str = CharToString(name); try { - std::shared_ptr impl = std::make_shared(name, type, shape, data, data_len); + std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len); return MSTensor(impl); } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; @@ -165,9 +167,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, MSTensor::MSTensor() : impl_(std::make_shared()) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(const std::shared_ptr &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); } -MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, - size_t data_len) - : impl_(std::make_shared(name, type, shape, data, data_len)) {} +MSTensor::MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) + : impl_(std::make_shared(CharToString(name), type, shape, data, data_len)) {} MSTensor::~MSTensor() = default; bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } @@ -179,9 +181,9 @@ MSTensor MSTensor::Clone() const { return ret; } -const std::string &MSTensor::Name() const { +std::vector MSTensor::CharName() const { MS_EXCEPTION_IF_NULL(impl_); - return impl_->Name(); + return StringToChar(impl_->Name()); } enum DataType MSTensor::DataType() const { diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 51462e090e..21861b5ee3 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -296,8 +296,7 @@ else() endif() endif() -add_dependencies(_c_dataengine mindspore_shared_lib) -target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib) +target_link_libraries(_c_dataengine PRIVATE mindspore_core mindspore_shared_lib) if(USE_GLOG) target_link_libraries(_c_dataengine PRIVATE mindspore::glog) diff --git a/mindspore/ccsrc/minddata/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h index 03be78fe1c..b4f517330e 100644 --- a/mindspore/ccsrc/minddata/dataset/include/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/include/tensor.h @@ -686,6 +686,9 @@ class Tensor { /// pointer to the end of the physical data unsigned char *data_end_ = nullptr; + /// shape for interpretation of YUV image + std::vector yuv_shape_; + private: friend class DETensor; diff --git a/mindspore/core/utils/status.cc b/mindspore/core/utils/status.cc index 8e7ec47eb2..43a386e343 100644 --- a/mindspore/core/utils/status.cc +++ b/mindspore/core/utils/status.cc @@ -24,16 +24,42 @@ #include namespace 
mindspore { -Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { - status_code_ = code; - line_of_code_ = line_of_code; - file_name_ = std::string(file_name); - err_description_ = extra; +struct Status::Data { + enum StatusCode status_code = kSuccess; + std::string status_msg; + int line_of_code = -1; + std::string file_name; + std::string err_description; +}; + +Status::Status() : data_(std::make_shared()) {} + +Status::Status(enum StatusCode status_code, const std::vector &status_msg) : data_(std::make_shared()) { + if (data_ == nullptr) { + return; + } + + data_->status_msg = CharToString(status_msg); + data_->status_code = status_code; +} + +Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra) + : data_(std::make_shared()) { + if (data_ == nullptr) { + return; + } + data_->status_code = code; + data_->line_of_code = line_of_code; + if (file_name != nullptr) { + data_->file_name = file_name; + } + data_->err_description = CharToString(extra); + std::ostringstream ss; #ifndef ENABLE_ANDROID ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". "; - if (!extra.empty()) { - ss << extra; + if (!data_->err_description.empty()) { + ss << data_->err_description; } ss << "\n"; #endif @@ -42,10 +68,38 @@ Status::Status(enum StatusCode code, int line_of_code, const char *file_name, co if (file_name != nullptr) { ss << "File : " << file_name << "\n"; } - status_msg_ = ss.str(); + data_->status_msg = ss.str(); +} + +enum StatusCode Status::StatusCode() const { + if (data_ == nullptr) { + return kSuccess; + } + return data_->status_code; +} + +std::vector Status::ToCString() const { + if (data_ == nullptr) { + return std::vector(); + } + return StringToChar(data_->status_msg); } -std::string Status::CodeAsString(enum StatusCode c) { +int Status::GetLineOfCode() const { + if (data_ == nullptr) { + return -1; + } + return data_->line_of_code; +} + +std::vector Status::GetErrDescriptionChar() const { + if (data_ == nullptr) { + return std::vector(); + } + return StringToChar(data_->status_msg); +} + +std::vector Status::CodeAsCString(enum StatusCode c) { static std::map info_map = {{kSuccess, "No error occurs."}, // Core {kCoreFailed, "Common error code."}, @@ -98,7 +152,7 @@ std::string Status::CodeAsString(enum StatusCode c) { {kLiteInferInvalid, "Invalid infer shape before runtime."}, {kLiteInputParamInvalid, "Invalid input param by user."}}; auto iter = info_map.find(c); - return iter == info_map.end() ? "Unknown error" : iter->second; + return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second); } std::ostream &operator<<(std::ostream &os, const Status &s) { @@ -106,22 +160,48 @@ std::ostream &operator<<(std::ostream &os, const Status &s) { return os; } -const std::string &Status::SetErrDescription(const std::string &err_description) { - err_description_ = err_description; +std::vector Status::SetErrDescription(const std::vector &err_description) { + if (data_ == nullptr) { + return std::vector(); + } + data_->err_description = CharToString(err_description); std::ostringstream ss; #ifndef ENABLE_ANDROID - ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(status_code_) << ". "; - if (!err_description_.empty()) { - ss << err_description_; + ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". 
"; + if (!data_->err_description.empty()) { + ss << data_->err_description; } ss << "\n"; #endif - if (line_of_code_ > 0 && !file_name_.empty()) { - ss << "Line of code : " << line_of_code_ << "\n"; - ss << "File : " << file_name_ << "\n"; + if (data_->line_of_code > 0 && !data_->file_name.empty()) { + ss << "Line of code : " << data_->line_of_code << "\n"; + ss << "File : " << data_->file_name << "\n"; + } + data_->status_msg = ss.str(); + return StringToChar(data_->status_msg); +} + +bool Status::operator==(const Status &other) const { + if (data_ == nullptr && other.data_ == nullptr) { + return true; + } + + if (data_ == nullptr || other.data_ == nullptr) { + return false; } - status_msg_ = ss.str(); - return status_msg_; + + return data_->status_code == other.data_->status_code; } + +bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; } +bool Status::operator!=(const Status &other) const { return !operator==(other); } +bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); } + +Status::operator bool() const { return (StatusCode() == kSuccess); } +Status::operator int() const { return static_cast(StatusCode()); } + +Status Status::OK() { return StatusCode::kSuccess; } +bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); } +bool Status::IsError() const { return !IsOk(); } } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index f94aa367c6..b28a27e557 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -67,7 +67,7 @@ Model::Model(const std::vector &network, const std::shared_ptr Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType) { +bool Model::CheckModelSupport(const std::vector &, ModelType) { MS_LOG(ERROR) << "Unsupported feature."; return false; } diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc index 660cf107ac..ca47d0cc1c 100644 --- a/mindspore/lite/src/cxx_api/serialization.cc +++ b/mindspore/lite/src/cxx_api/serialization.cc @@ -47,7 +47,7 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy return graph; } -Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { +Graph Serialization::LoadModel(const std::vector &file, ModelType model_type) { MS_LOG(ERROR) << "Unsupported Feature."; return Graph(nullptr); } diff --git a/mindspore/lite/src/cxx_api/types.cc b/mindspore/lite/src/cxx_api/types.cc index 876780459b..12f987067a 100644 --- a/mindspore/lite/src/cxx_api/types.cc +++ b/mindspore/lite/src/cxx_api/types.cc @@ -60,16 +60,16 @@ class Buffer::Impl { MSTensor::MSTensor() : impl_(std::make_shared()) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(const std::shared_ptr &impl) : impl_(impl) {} -MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, - size_t data_len) - : impl_(std::make_shared(name, type, shape, data, data_len)) {} +MSTensor::MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) + : impl_(std::make_shared(CharToString(name), type, shape, data, data_len)) {} MSTensor::~MSTensor() = default; bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } -MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const 
std::vector &shape, +MSTensor MSTensor::CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { - auto impl = std::make_shared(name, type, shape, data, data_len); + auto impl = std::make_shared(CharToString(name), type, shape, data, data_len); if (impl == nullptr) { MS_LOG(ERROR) << "Allocate tensor impl failed."; return MSTensor(nullptr); @@ -77,7 +77,7 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con return MSTensor(impl); } -MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, +MSTensor MSTensor::CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { auto tensor = CreateTensor(name, type, shape, data, data_len); if (tensor == nullptr) { @@ -98,13 +98,12 @@ MSTensor MSTensor::Clone() const { return ret; } -const std::string &MSTensor::Name() const { - static std::string empty = ""; +std::vector MSTensor::CharName() const { if (impl_ == nullptr) { MS_LOG(ERROR) << "Invalid tensor inpmlement."; - return empty; + return std::vector(); } - return impl_->Name(); + return StringToChar(impl_->Name()); } int64_t MSTensor::ElementNum() const { diff --git a/tests/ut/cpp/cxx_api/context_test.cc b/tests/ut/cpp/cxx_api/context_test.cc index 8509f0457e..204a05dbd6 100644 --- a/tests/ut/cpp/cxx_api/context_test.cc +++ b/tests/ut/cpp/cxx_api/context_test.cc @@ -60,12 +60,6 @@ TEST_F(TestCxxApiContext, test_context_ascend310_context_nullptr_FAILED) { EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr)); } -TEST_F(TestCxxApiContext, test_context_ascend310_context_wrong_type_SUCCESS) { - auto ctx = std::make_shared(); - ctx->params["mindspore.option.op_select_impl_mode"] = 5; - ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), ""); -} - TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) { auto ctx = std::make_shared(); ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
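
Note on the pattern used throughout this patch (illustrative, not part of the patch): the public std::string methods are now defined inline in the installed headers and forward to private std::vector<char> overloads through StringToChar/CharToString, so std::string never appears in an exported signature. A library built this way can be consumed by applications compiled with either setting of the libstdc++ dual ABI (_GLIBCXX_USE_CXX11_ABI=0 or 1), because std::vector<char> has the same layout under both. A minimal self-contained sketch of the same idiom, with hypothetical names (Config, SetDeviceTarget):

// Illustrative sketch only; compiles standalone and is not part of the patch.
#include <iostream>
#include <string>
#include <vector>

namespace demo {
// Helpers equivalent to include/api/dual_abi_helper.h.
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }

class Config {
 public:
  // Header-only wrappers: std::string stays on the caller's side of the
  // library boundary, so the exported symbols do not depend on the string ABI.
  inline void SetDeviceTarget(const std::string &target) { SetDeviceTarget(StringToChar(target)); }
  inline std::string GetDeviceTarget() const { return CharToString(GetDeviceTargetChar()); }

 private:
  // ABI-stable interface: only std::vector<char> appears in the signature.
  // In the real patch these overloads are defined out of line in the .cc files.
  void SetDeviceTarget(const std::vector<char> &target) { target_ = target; }
  std::vector<char> GetDeviceTargetChar() const { return target_; }

  std::vector<char> target_;
};
}  // namespace demo

int main() {
  demo::Config config;
  config.SetDeviceTarget("Ascend310");
  std::cout << config.GetDeviceTarget() << std::endl;  // prints "Ascend310"
  return 0;
}

In this patch the char-based overloads are what the shared library actually exports; their definitions live in mindspore/ccsrc/cxx_api/context.cc, mindspore/ccsrc/cxx_api/types.cc and mindspore/core/utils/status.cc, while the string wrappers stay inline in include/api.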
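Caller code keeps using plain std::string; the conversion is hidden inside the inline wrappers. A hedged usage sketch based only on the signatures declared in this patch; the file names and option values below are placeholders, not defaults:

#include <memory>
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "include/api/status.h"
#include "include/api/types.h"

int main() {
  // Process-wide device selection; these call the inline std::string wrappers,
  // which forward to the exported std::vector<char> overloads.
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  // Per-model options stored in the new pimpl'd Context ("insert_op.cfg" and
  // "allow_fp32_to_fp16" are placeholder values).
  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::ModelContext::SetInsertOpConfigPath(model_context, "insert_op.cfg");
  mindspore::ModelContext::SetPrecisionMode(model_context, "allow_fp32_to_fp16");

  // The std::string overload of LoadModel is inline and forwards to the
  // std::vector<char> overload inside the library ("model.mindir" is a placeholder).
  mindspore::Graph graph = mindspore::Serialization::LoadModel("model.mindir", mindspore::kMindIR);

  mindspore::Status status = mindspore::Status::OK();
  return status.IsOk() ? 0 : 1;
}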