cxx api refactor: tensor/status/model

Branch: pull/11526/head
Author: lixian (committed by zhoufeng)
Parent: f1009cb21b
Commit: 7d2fd6e76c

@@ -65,7 +65,7 @@ install(
 install(
     TARGETS mindspore_shared_lib
-    LIBRARY DESTINATION ${INSTALL_LIB_DIR}
+    DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )
@@ -327,7 +327,7 @@ install(
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/transforms.h
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision.h
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision_lite.h
-    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/minddata_eager.h
+    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/execute.h
     DESTINATION ${INSTALL_BASE_DIR}/include/minddata/dataset/include
     COMPONENT mindspore
 )

@@ -109,6 +109,8 @@ if(PLATFORM_ARM64)
             COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
     if(ENABLE_TOOLS)
         install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
     endif()
@@ -128,6 +130,8 @@ elseif(PLATFORM_ARM32)
            COMPONENT ${RUNTIME_COMPONENT_NAME})
    install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
            COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     if(ENABLE_TOOLS)
         install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
     endif()
@@ -162,6 +166,8 @@ elseif(WIN32)
     endif()
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     set(WIN_LIB_DIR_RUN_X86 ${RUNTIME_PKG_NAME}/benchmark)
     install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${WIN_LIB_DIR_RUN_X86}
             COMPONENT ${RUNTIME_COMPONENT_NAME})
@@ -182,6 +188,8 @@ else()
     endif()
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
             COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}

@@ -24,7 +24,6 @@
 #include "include/api/graph.h"

 namespace mindspore {
-namespace api {
 class InputAndOutput;
 using Input = InputAndOutput;
 using Output = InputAndOutput;
@@ -35,7 +34,7 @@ class MS_API CellBase {
   virtual ~CellBase() = default;
   virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
   virtual std::shared_ptr<CellBase> Clone() const = 0;
-  virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
+  virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
   std::vector<Output> operator()(const std::vector<Input> &inputs) const;
 };
@@ -57,16 +56,16 @@ class MS_API ParameterCell final : public Cell<ParameterCell> {
   ParameterCell(ParameterCell &&);
   ParameterCell &operator=(ParameterCell &&);

-  explicit ParameterCell(const Tensor &);
-  ParameterCell &operator=(const Tensor &);
+  explicit ParameterCell(const MSTensor &);
+  ParameterCell &operator=(const MSTensor &);

-  explicit ParameterCell(Tensor &&);
-  ParameterCell &operator=(Tensor &&);
+  explicit ParameterCell(MSTensor &&);
+  ParameterCell &operator=(MSTensor &&);

-  Tensor GetTensor() const { return tensor_; }
+  MSTensor GetTensor() const { return tensor_; }

  private:
-  Tensor tensor_;
+  MSTensor tensor_;
 };

 class MS_API OpCellBase : public CellBase {
@@ -99,11 +98,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
   explicit GraphCell(const std::shared_ptr<Graph> &);

   const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
-  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
+  std::vector<MSTensor> GetInputs();
+  std::vector<MSTensor> GetOutputs();

  private:
   friend class ModelImpl;
@@ -119,8 +116,8 @@ class MS_API InputAndOutput {
   ~InputAndOutput() = default;

   // no explicit
-  InputAndOutput(const Tensor &);  // NOLINT(runtime/explicit)
-  InputAndOutput(Tensor &&);       // NOLINT(runtime/explicit)
+  InputAndOutput(const MSTensor &);  // NOLINT(runtime/explicit)
+  InputAndOutput(MSTensor &&);       // NOLINT(runtime/explicit)

   InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
@@ -132,6 +129,5 @@ class MS_API InputAndOutput {
   std::vector<InputAndOutput> prev_;
   int32_t index_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CELL_H
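
Taken together, the new Cell surface trades Buffer for MSTensor and collapses the multi-out-parameter info queries into tensor getters. A minimal sketch of driving it, under the signatures shown above (the RunGraph helper is hypothetical, and the graph is assumed to be loaded elsewhere, e.g. via Serialization::LoadModel):

#include <memory>
#include <vector>
#include "include/api/cell.h"

// Hypothetical driver for the refactored GraphCell API.
void RunGraph(const std::shared_ptr<mindspore::Graph> &graph,
              const std::vector<mindspore::MSTensor> &inputs) {
  mindspore::GraphCell cell(graph);
  std::vector<mindspore::MSTensor> outputs;
  // Run() now trades in MSTensor instead of Buffer and reports kSuccess.
  if (cell.Run(inputs, &outputs) != mindspore::kSuccess) {
    return;  // inspect the returned Status in real code
  }
  // GetInputs()/GetOutputs() replace GetInputsInfo()/GetOutputsInfo().
  for (auto &tensor : cell.GetOutputs()) {
    (void)tensor.DataSize();
  }
}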

@@ -16,26 +16,49 @@
 #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
 #define MINDSPORE_INCLUDE_API_CONTEXT_H

+#include <map>
+#include <any>
 #include <string>
 #include <memory>
 #include "include/api/types.h"

 namespace mindspore {
-namespace api {
-class MS_API Context {
- public:
-  static Context &Instance();
-  const std::string &GetDeviceTarget() const;
-  Context &SetDeviceTarget(const std::string &device_target);
-  uint32_t GetDeviceID() const;
-  Context &SetDeviceID(uint32_t device_id);
-
- private:
-  Context();
-  ~Context();
-  class ContextImpl;
-  std::shared_ptr<ContextImpl> impl_;
+constexpr auto kDeviceTypeAscend310 = "Ascend310";
+constexpr auto kDeviceTypeAscend910 = "Ascend910";
+
+struct MS_API Context {
+  virtual ~Context() = default;
+  std::map<std::string, std::any> params;
+};
+
+struct MS_API GlobalContext : public Context {
+  static std::shared_ptr<Context> GetGlobalContext();
+
+  static void SetGlobalDeviceTarget(const std::string &device_target);
+  static std::string GetGlobalDeviceTarget();
+
+  static void SetGlobalDeviceID(const uint32_t &device_id);
+  static uint32_t GetGlobalDeviceID();
+};
+
+struct MS_API ModelContext : public Context {
+  static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
+  static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
+
+  static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
+  static std::string GetInputFormat(const std::shared_ptr<Context> &context);
+
+  static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
+  static std::string GetInputShape(const std::shared_ptr<Context> &context);
+
+  static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
+  static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
+
+  static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
+  static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
+
+  static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
+  static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
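
To make the new shape concrete, a minimal usage sketch of the split context API, assuming only the declarations above; the option strings ("nhwc", "allow_fp32_to_fp16") follow the comments carried over from the old types.h constants, and the exact accepted values are not spelled out in this diff:

#include <memory>
#include "include/api/context.h"

int main() {
  // Process-wide device settings move to static GlobalContext accessors.
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  // Per-model options hang off an ordinary Context object via ModelContext.
  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::ModelContext::SetInputFormat(model_context, "nhwc");
  mindspore::ModelContext::SetPrecisionMode(model_context, "allow_fp32_to_fp16");
  return 0;
}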

@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
+#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_
+
+namespace mindspore {
+enum class DataType : int {
+  kTypeUnknown = 0,
+  kObjectTypeString = 12,
+  kObjectTypeList = 13,
+  kObjectTypeTuple = 14,
+  kObjectTypeTensorType = 17,
+  kNumberTypeBool = 30,
+  kNumberTypeInt8 = 32,
+  kNumberTypeInt16 = 33,
+  kNumberTypeInt32 = 34,
+  kNumberTypeInt64 = 35,
+  kNumberTypeUInt8 = 37,
+  kNumberTypeUInt16 = 38,
+  kNumberTypeUInt32 = 39,
+  kNumberTypeUInt64 = 40,
+  kNumberTypeFloat16 = 42,
+  kNumberTypeFloat32 = 43,
+  kNumberTypeFloat64 = 44,
+  kNumberTypeEnd = 46,
+  // add new enum here
+  kInvalidType = INT32_MAX,
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_INCLUDE_API_DATA_TYPE_H_

@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_INCLUDE_API_GRAPH_H
 #define MINDSPORE_INCLUDE_API_GRAPH_H

+#include <cstddef>
 #include <string>
 #include <vector>
 #include <map>
@@ -24,21 +25,21 @@
 #include "include/api/types.h"

 namespace mindspore {
-namespace api {
 class MS_API Graph {
  public:
   class GraphData;
   explicit Graph(const std::shared_ptr<GraphData> &graph_data);
   explicit Graph(std::shared_ptr<GraphData> &&graph_data);
+  explicit Graph(std::nullptr_t);
   ~Graph();

   enum ModelType ModelType() const;
+  bool operator==(std::nullptr_t) const;

  private:
   friend class GraphCell;
   friend class ModelImpl;
   std::shared_ptr<GraphData> graph_data_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_GRAPH_H

@@ -0,0 +1,77 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
+#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
+
+#include <string>
+#include <memory>
+#include <map>
+#include <any>
+#include "include/api/types.h"
+
+namespace mindspore {
+namespace lite {
+/// \brief CpuBindMode defined for holding bind cpu strategy argument.
+typedef enum : uint32_t {
+  NO_BIND = 0,    /**< no bind */
+  HIGHER_CPU = 1, /**< bind higher cpu first */
+  MID_CPU = 2     /**< bind middle cpu first */
+} CpuBindMode;
+
+class Allocator;
+}  // namespace lite
+
+struct MS_API Context {
+ public:
+  static void Clear(const std::shared_ptr<Context> &contxet);
+
+  static void SetAsDefault(const std::shared_ptr<Context> &contxet);
+
+  static void SetVendorName(const std::shared_ptr<Context> &contxet, const std::string &name);
+  static std::string GetVendorName(const std::shared_ptr<Context> &contxet);
+
+  static void SetThreadNum(const std::shared_ptr<Context> &contxet, int num);
+  static int GetThreadNum(const std::shared_ptr<Context> &contxet);
+
+  static void SetAllocator(const std::shared_ptr<Context> &contxet, std::shared_ptr<lite::Allocator> alloc);
+  static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &contxet);
+
+  static void ConfigCPU(const std::shared_ptr<Context> &contxet, bool config);
+  static bool IfCPUEnabled(const std::shared_ptr<Context> &contxet);
+
+  static void ConfigCPUFp16(const std::shared_ptr<Context> &contxet, bool config);
+  static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &contxet);
+
+  static void SetCPUBindMode(const std::shared_ptr<Context> &contxet, lite::CpuBindMode mode);
+  static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &contxet);
+
+  static void ConfigGPU(const std::shared_ptr<Context> &contxet, bool config);
+  static bool IfGPUEnabled(const std::shared_ptr<Context> &contxet);
+
+  static void ConfigGPUFp16(const std::shared_ptr<Context> &contxet, bool config);
+  static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &contxet);
+
+  static void ConfigNPU(const std::shared_ptr<Context> &contxet, bool config);
+  static bool IfNPUEnabled(const std::shared_ptr<Context> &contxet);
+
+  static void SetNPUFrequency(const std::shared_ptr<Context> &contxet, int freq);
+  static int GetNPUFrequency(const std::shared_ptr<Context> &contxet);
+
+ private:
+  std::map<std::string, std::any> context_;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
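
A minimal sketch of how a caller might configure the Lite-side Context through these static setters (the values are hypothetical; note that the header's parameter-name typo "contxet" is internal to the declarations and does not affect call sites):

#include <memory>
#include "include/api/lite_context.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  // All configuration goes through static setters keyed off the shared context.
  mindspore::Context::SetThreadNum(context, 2);
  mindspore::Context::ConfigCPU(context, true);
  mindspore::Context::SetCPUBindMode(context, mindspore::lite::MID_CPU);
  return 0;
}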

@@ -20,41 +20,36 @@
 #include <vector>
 #include <map>
 #include <memory>
+#include <utility>
 #include "include/api/status.h"
 #include "include/api/types.h"
 #include "include/api/graph.h"
 #include "include/api/cell.h"

 namespace mindspore {
-namespace api {
 class ModelImpl;
-// todo: minddata c++ interface
-class DataSet {};
+struct Context;

 class MS_API Model {
  public:
-  explicit Model(const std::vector<Output> &network);
-  explicit Model(const GraphCell &graph);
+  explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
+  explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
   ~Model();
   Model(const Model &) = delete;
   void operator=(const Model &) = delete;

-  Status Build(const std::map<std::string, std::string> &options);
+  Status Build();
+  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);

-  Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
-  Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
-  Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+  std::vector<MSTensor> GetInputs();
+  std::vector<MSTensor> GetOutputs();

   static bool CheckModelSupport(const std::string &device_type, ModelType model_type);

  private:
   std::shared_ptr<ModelImpl> impl_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_MODEL_H
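
End to end, the refactored Model flow collapses to load, build, then predict. A hedged sketch stitched together from the headers in this diff ("model.mindir" is a placeholder path, and error handling is trimmed to Status checks):

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"

int main() {
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  auto graph = mindspore::Serialization::LoadModel("model.mindir", mindspore::kMindIR);

  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::Model model(mindspore::GraphCell(graph), model_context);
  if (model.Build() != mindspore::kSuccess) {  // options now live in the context
    return 1;
  }

  // Inputs come from the model itself instead of GetInputsInfo out-parameters.
  std::vector<mindspore::MSTensor> inputs = model.GetInputs();
  std::vector<mindspore::MSTensor> outputs;
  return model.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : 1;
}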

@@ -25,7 +25,6 @@
 #include "include/api/cell.h"

 namespace mindspore {
-namespace api {
 struct MS_API Conv2D : public OpCell<Conv2D> {
   Conv2D() : OpCell("Conv2D") {}
   ~Conv2D() override = default;
@@ -45,6 +44,5 @@ struct MS_API Conv2D : public OpCell<Conv2D> {
   std::vector<int> dilation = {1, 1, 1, 1};
   int group = 1;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_OPS_OPS_H

@@ -26,15 +26,14 @@
 #include "include/api/graph.h"

 namespace mindspore {
-namespace api {
 class MS_API Serialization {
  public:
+  static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
   static Graph LoadModel(const std::string &file, ModelType model_type);
   static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
   static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
   static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
   static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
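
The new overload makes in-memory loading symmetric with file loading. A small sketch (the helper names and buf/size are hypothetical):

#include <string>
#include "include/api/serialization.h"

// Load a model image the caller already holds in memory.
mindspore::Graph LoadFromMemory(const void *buf, size_t size) {
  return mindspore::Serialization::LoadModel(buf, size, mindspore::kMindIR);
}

// The existing file-path overload is unchanged.
mindspore::Graph LoadFromFile(const std::string &path) {
  return mindspore::Serialization::LoadModel(path, mindspore::kMindIR);
}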

@@ -17,37 +17,129 @@
 #define MINDSPORE_INCLUDE_API_STATUS_H

 #include <string>
+#include <ostream>
+#include <climits>

 namespace mindspore {
-namespace api {
-enum StatusCode {
-  SUCCESS = 0,
-  FAILED,
-  INVALID_INPUTS,
-  // insert new status code here
-  UNKNOWN = 0xFFFFFFFF
+enum CompCode : uint32_t {
+  kCore = 0x00000000u,
+  kMD = 0x10000000u,
+  kME = 0x20000000u,
+  kMC = 0x30000000u,
+  kLite = 0xF0000000u,
+};
+
+enum StatusCode : uint32_t {
+  kSuccess = 0,
+  // Core
+  kCoreFailed = kCore | 0x1,
+
+  // MD
+  kMDOutOfMemory = kMD | 1,
+  kMDShapeMisMatch = kMD | 2,
+  kMDInterrupted = kMD | 3,
+  kMDNoSpace = kMD | 4,
+  kMDPyFuncException = kMD | 5,
+  kMDDuplicateKey = kMD | 6,
+  kMDPythonInterpreterFailure = kMD | 7,
+  kMDTDTPushFailure = kMD | 8,
+  kMDFileNotExist = kMD | 9,
+  kMDProfilingError = kMD | 10,
+  kMDBoundingBoxOutOfBounds = kMD | 11,
+  kMDBoundingBoxInvalidShape = kMD | 12,
+  kMDSyntaxError = kMD | 13,
+  kMDTimeOut = kMD | 14,
+  kMDBuddySpaceFull = kMD | 15,
+  kMDNetWorkError = kMD | 16,
+  kMDNotImplementedYet = kMD | 17,
+  // Make this error code the last one. Add new error code above it.
+  kMDUnexpectedError = kMD | 127,
+
+  // ME
+  kMEFailed = kME | 0x1,
+  kMEInvalidInput = kME | 0x2,
+
+  // MC
+  kMCFailed = kMC | 0x1,
+  kMCDeviceError = kMC | 0x2,
+  kMCInvalidInput = kMC | 0x3,
+  kMCInvalidArgs = kMC | 0x4,
+
+  // Lite  // Common error code, range: [-1, -100)
+  kLiteError = kLite | (0x0FFFFFFF & -1),           /**< Common error code. */
+  kLiteNullptr = kLite | (0x0FFFFFFF & -2),         /**< NULL pointer returned. */
+  kLiteParamInvalid = kLite | (0x0FFFFFFF & -3),    /**< Invalid parameter. */
+  kLiteNoChange = kLite | (0x0FFFFFFF & -4),        /**< No change. */
+  kLiteSuccessExit = kLite | (0x0FFFFFFF & -5),     /**< No error but exit. */
+  kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6),    /**< Fail to create memory. */
+  kLiteNotSupport = kLite | (0x0FFFFFFF & -7),      /**< Fail to support. */
+  kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
+
+  // Executor error code, range: [-100, -200)
+  kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
+  kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
+  kLiteReentrantError = kLite | (0x0FFFFFFF & -102),   /**< Exist executor running. */
+
+  // Graph error code, range: [-200, -300)
+  kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
+
+  // Node error code, range: [-300, -400)
+  kLiteNotFindOp = kLite | (0x0FFFFFFF & -300),        /**< Failed to find operator. */
+  kLiteInvalidOpName = kLite | (0x0FFFFFFF & -301),    /**< Invalid operator name. */
+  kLiteInvalidOpAttr = kLite | (0x0FFFFFFF & -302),    /**< Invalid operator attr. */
+  kLiteOpExecuteFailure = kLite | (0x0FFFFFFF & -303), /**< Failed to execute operator. */
+
+  // Tensor error code, range: [-400, -500)
+  kLiteFormatError = kLite | (0x0FFFFFFF & -400), /**< Failed to check tensor format. */
+
+  // InferShape error code, range: [-500, -600)
+  kLiteInferError = kLite | (0x0FFFFFFF & -500),   /**< Failed to infer shape. */
+  kLiteInferInvalid = kLite | (0x0FFFFFFF & -501), /**< Invalid infer shape before runtime. */
+
+  // User input param error code, range: [-600, -700)
+  kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
 };

 class Status {
  public:
-  Status() : status_code_(FAILED) {}
+  Status() : status_code_(kSuccess), line_of_code_(-1) {}
   Status(enum StatusCode status_code, const std::string &status_msg = "")  // NOLINT(runtime/explicit)
-      : status_code_(status_code), status_msg_(status_msg) {}
+      : status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
+  Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
   ~Status() = default;

-  bool IsSuccess() const { return status_code_ == SUCCESS; }
   enum StatusCode StatusCode() const { return status_code_; }
-  std::string StatusMessage() const { return status_msg_; }
+  const std::string &ToString() const { return status_msg_; }
+
+  int GetLineOfCode() const { return line_of_code_; }
+  const std::string &GetErrDescription() const { return status_msg_; }
+  const std::string &SetErrDescription(const std::string &err_description);
+
+  friend std::ostream &operator<<(std::ostream &os, const Status &s);
+
   bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
   bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
   bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
   bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
-  operator bool() const = delete;
+
+  explicit operator bool() const { return (status_code_ == kSuccess); }
+  explicit operator int() const { return static_cast<int>(status_code_); }
+
+  static Status OK() { return Status(StatusCode::kSuccess); }
+
+  bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
+  bool IsError() const { return !IsOk(); }
+
+  static std::string CodeAsString(enum StatusCode c);

  private:
   enum StatusCode status_code_;
   std::string status_msg_;
+  int line_of_code_;
+  std::string file_name_;
+  std::string err_description_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_STATUS_H
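
A short sketch of the new Status conventions: component-prefixed codes, explicit bool conversion in place of the old deleted one, and ToString() replacing StatusMessage(). The DoWork helper is hypothetical:

#include <iostream>
#include "include/api/status.h"

mindspore::Status DoWork(bool ok) {
  if (!ok) {
    return mindspore::Status(mindspore::kMCInvalidInput, "bad input");
  }
  return mindspore::Status::OK();
}

int main() {
  mindspore::Status s = DoWork(false);
  if (!s.IsOk()) {  // equivalently: if (!static_cast<bool>(s))
    std::cout << s.ToString() << std::endl;
  }
  return s == mindspore::kSuccess ? 0 : 1;
}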

@@ -16,15 +16,20 @@
 #ifndef MINDSPORE_INCLUDE_API_TYPES_H
 #define MINDSPORE_INCLUDE_API_TYPES_H

+#include <cstddef>
 #include <string>
 #include <vector>
 #include <memory>
+#include "include/api/data_type.h"

+#ifdef _WIN32
+#define MS_API __declspec(dllexport)
+#else
 #define MS_API __attribute__((visibility("default")))
+#endif

 namespace mindspore {
-namespace api {
-enum ModelType {
+enum ModelType : uint32_t {
   kMindIR = 0,
   kAIR = 1,
   kOM = 2,
@@ -33,52 +38,38 @@ enum ModelType {
   kUnknownType = 0xFFFFFFFF
 };

-enum DataType {
-  kMsUnknown = 0,
-  kMsBool = 1,
-  kMsInt8 = 2,
-  kMsInt16 = 3,
-  kMsInt32 = 4,
-  kMsInt64 = 5,
-  kMsUint8 = 6,
-  kMsUint16 = 7,
-  kMsUint32 = 8,
-  kMsUint64 = 9,
-  kMsFloat16 = 10,
-  kMsFloat32 = 11,
-  kMsFloat64 = 12,
-  // insert new data type here
-  kInvalidDataType = 0xFFFFFFFF
-};
-
-class MS_API Tensor {
+class MS_API MSTensor {
  public:
-  Tensor();
-  Tensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
-  ~Tensor();
+  class Impl;

-  const std::string &Name() const;
-  void SetName(const std::string &name);
+  static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                               const void *data, size_t data_len) noexcept;
+  static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                                  const void *data, size_t data_len) noexcept;

-  api::DataType DataType() const;
-  void SetDataType(api::DataType type);
+  MSTensor();
+  explicit MSTensor(const std::shared_ptr<Impl> &impl);
+  MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
+           size_t data_len);
+  ~MSTensor();

+  const std::string &Name() const;
+  enum DataType DataType() const;
   const std::vector<int64_t> &Shape() const;
-  void SetShape(const std::vector<int64_t> &shape);
+  int64_t ElementNum() const;

-  const void *Data() const;
+  std::shared_ptr<const void> Data() const;
   void *MutableData();
   size_t DataSize() const;

-  bool ResizeData(size_t data_len);
-  bool SetData(const void *data, size_t data_len);
+  bool IsDevice() const;

-  int64_t ElementNum() const;
-  static int GetTypeSize(api::DataType type);
-  Tensor Clone() const;
+  MSTensor Clone() const;
+  bool operator==(std::nullptr_t) const;

  private:
-  class Impl;
+  friend class ModelImpl;
+  explicit MSTensor(std::nullptr_t);
   std::shared_ptr<Impl> impl_;
 };
@@ -101,21 +92,5 @@ class MS_API Buffer {
   class Impl;
   std::shared_ptr<Impl> impl_;
 };
-
-extern MS_API const char *kDeviceTypeAscend310;
-extern MS_API const char *kDeviceTypeAscend910;
-extern MS_API const char *kDeviceTypeGpu;
-
-constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
-constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
-constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
-// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
-constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
-constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
-constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
-// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
-constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
-// "high_precision" or "high_performance", default as "high_performance"
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_TYPES_H
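
A minimal sketch of creating a tensor with the new factory functions; the copy-versus-reference split between CreateTensor and CreateRefTensor is inferred from the names, not stated in this diff:

#include <vector>
#include "include/api/types.h"

int main() {
  float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // CreateTensor presumably copies the buffer; CreateRefTensor, by its name,
  // wraps the caller's memory instead.
  auto t = mindspore::MSTensor::CreateTensor("x", mindspore::DataType::kNumberTypeFloat32,
                                             {2, 2}, data, sizeof(data));
  return t == nullptr ? 1 : static_cast<int>(t.ElementNum());  // 4 elements
}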

@@ -23,7 +23,7 @@ if(ENABLE_D)
 endif()

 if(ENABLE_GPU)
-    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc")
 endif()

 set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@@ -45,8 +45,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
         -Wl,-force_load mindspore -Wl,-noall_load proto_input mindspore_gvar mindspore::protobuf)
 else()
-    target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-        -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
+    if(ENABLE_D OR ENABLE_ACL)
+        target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+            -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
+    else()
+        target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+            mindspore proto_input mindspore_gvar mindspore::protobuf)
+    endif()
 endif()

 if(ENABLE_CPU)

@@ -18,7 +18,7 @@
 #include "cxx_api/factory.h"
 #include "cxx_api/graph/graph_impl.h"

-namespace mindspore::api {
+namespace mindspore {
 std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }

 ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
@@ -40,23 +40,23 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
   return *this;
 }

-ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {}
+ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}

-ParameterCell &ParameterCell::operator=(const Tensor &tensor) {
+ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
   tensor_ = tensor.Clone();
   return *this;
 }

-ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {}
+ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}

-ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
+ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
   tensor_ = tensor;
   return *this;
 }

 GraphCell::GraphCell(const Graph &graph)
     : graph_(std::make_shared<Graph>(graph)),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
@@ -64,7 +64,7 @@ GraphCell::GraphCell(const Graph &graph)
 GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
     : graph_(graph),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
@@ -72,13 +72,13 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
 GraphCell::GraphCell(Graph &&graph)
     : graph_(std::make_shared<Graph>(graph)),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
 }

-Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   return executor_->Run(inputs, outputs);
 }
@@ -88,25 +88,24 @@ Status GraphCell::Load() {
   return executor_->Load();
 }

-Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+std::vector<MSTensor> GraphCell::GetInputs() {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
+  return executor_->GetInputs();
 }

-Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                 std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+std::vector<MSTensor> GraphCell::GetOutputs() {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
+  return executor_->GetOutputs();
 }

 InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}

-InputAndOutput::InputAndOutput(const Tensor &tensor)
+InputAndOutput::InputAndOutput(const MSTensor &tensor)
     : cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
-InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
+InputAndOutput::InputAndOutput(MSTensor &&tensor)
+    : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}

 InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
                                int32_t index)
     : cell_(cell), prev_(prev), index_(index) {}
-}  // namespace mindspore::api
+}  // namespace mindspore

@@ -16,49 +16,119 @@
 #include "include/api/context.h"
 #include "utils/log_adapter.h"

-namespace mindspore::api {
-class Context::ContextImpl {
- public:
-  ContextImpl() : device_target_("NotSet"), device_id_(0) {}
-  ~ContextImpl() = default;
-  const std::string &GetDeviceTarget() const { return device_target_; }
-  void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
-  uint32_t GetDeviceID() const { return device_id_; }
-  void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
-
- private:
-  std::string device_target_;
-  uint32_t device_id_;
-};
+constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
+constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
+constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
+constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
+constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
+// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
+constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
+constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
+// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
+constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";

-Context &Context::Instance() {
-  static Context context;
-  return context;
+namespace mindspore {
+template <class T>
+static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
+  auto iter = context->params.find(key);
+  if (iter == context->params.end()) {
+    return T();
+  }
+  const std::any &value = iter->second;
+  if (value.type() != typeid(T)) {
+    return T();
+  }
+  return std::any_cast<T>(value);
 }

-const std::string &Context::GetDeviceTarget() const {
-  MS_EXCEPTION_IF_NULL(impl_);
-  return impl_->GetDeviceTarget();
+std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
+  static std::shared_ptr<Context> g_context = std::make_shared<Context>();
+  return g_context;
 }

-Context &Context::SetDeviceTarget(const std::string &device_target) {
-  MS_EXCEPTION_IF_NULL(impl_);
-  impl_->SetDeviceTarget(device_target);
-  return *this;
+void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  global_context->params[kGlobalContextDeviceTarget] = device_target;
 }

-uint32_t Context::GetDeviceID() const {
-  MS_EXCEPTION_IF_NULL(impl_);
-  return impl_->GetDeviceID();
+std::string GlobalContext::GetGlobalDeviceTarget() {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
 }

-Context &Context::SetDeviceID(uint32_t device_id) {
-  MS_EXCEPTION_IF_NULL(impl_);
-  impl_->SetDeviceID(device_id);
-  return *this;
+void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  global_context->params[kGlobalContextDeviceID] = device_id;
 }

-Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
+uint32_t GlobalContext::GetGlobalDeviceID() {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
+}
+
+void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInsertOpCfgPath] = cfg_path;
+}
+
+std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
+}
+
+void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInputFormat] = format;
+}
+
+std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInputFormat);
+}
+
+void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInputShape] = shape;
+}
+
+std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInputShape);
+}
+
+void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionOutputType] = output_type;
+}
+
+enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<enum DataType>(context, kModelOptionOutputType);
+}
+
+void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionPrecisionMode] = precision_mode;
+}
+
+std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionPrecisionMode);
+}

-Context::~Context() {}
-}  // namespace mindspore::api
+void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                       const std::string &op_select_impl_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
+}
+
+std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
+}
+}  // namespace mindspore
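
The std::any-backed params map plus the GetValue template above is the entire storage mechanism behind both context classes: a miss or a type mismatch degrades to a default-constructed value rather than throwing. A standalone illustration of the same lookup pattern (the Lookup helper is hypothetical, not part of the commit):

#include <any>
#include <map>
#include <string>

template <class T>
T Lookup(const std::map<std::string, std::any> &params, const std::string &key) {
  auto iter = params.find(key);
  // Missing key or wrong stored type both yield T()'s default value.
  if (iter == params.end() || iter->second.type() != typeid(T)) {
    return T();
  }
  return std::any_cast<T>(iter->second);
}

int main() {
  std::map<std::string, std::any> params{{"device_id", 3u}};
  return static_cast<int>(Lookup<uint32_t>(params, "device_id"));  // returns 3
}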

@@ -23,7 +23,7 @@
 #include <utility>
 #include "utils/utils.h"

-namespace mindspore::api {
+namespace mindspore {
 template <class T>
 class Factory {
   using U = std::function<std::shared_ptr<T>()>;
@@ -79,5 +79,5 @@ class Registrar {
 #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS)                              \
   static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
       #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_FACTORY_H

@@ -17,8 +17,8 @@
 #include "utils/log_adapter.h"
 #include "acl/acl.h"

-namespace mindspore::api {
-std::weak_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
+namespace mindspore {
+std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
 std::mutex AclEnvGuard::global_acl_env_mutex_;

 AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
@@ -42,7 +42,7 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
   std::shared_ptr<AclEnvGuard> acl_env;

   std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
-  acl_env = global_acl_env_.lock();
+  acl_env = global_acl_env_;
   if (acl_env != nullptr) {
     MS_LOG(INFO) << "Acl has been initialized, skip.";
   } else {
@@ -57,4 +57,4 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
   }
   return acl_env;
 }
-}  // namespace mindspore::api
+}  // namespace mindspore

@@ -20,7 +20,7 @@
 #include <mutex>
 #include "acl/acl_base.h"

-namespace mindspore::api {
+namespace mindspore {
 class __attribute__((visibility("default"))) AclEnvGuard {
  public:
   explicit AclEnvGuard(std::string_view cfg_file);
@@ -29,10 +29,10 @@ class __attribute__((visibility("default"))) AclEnvGuard {
   static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);

  private:
-  static std::weak_ptr<AclEnvGuard> global_acl_env_;
+  static std::shared_ptr<AclEnvGuard> global_acl_env_;
   static std::mutex global_acl_env_mutex_;
   aclError errno_;
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H

@ -16,53 +16,50 @@
#include "cxx_api/graph/acl/acl_graph_impl.h" #include "cxx_api/graph/acl/acl_graph_impl.h"
#include "include/api/context.h" #include "include/api/context.h"
#include "cxx_api/model/acl/model_converter.h" #include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/python_utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore::api { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
AclGraphImpl::AclGraphImpl() AclGraphImpl::AclGraphImpl()
: init_flag_(false), : init_flag_(false),
load_flag_(false), load_flag_(false),
device_type_("AscendCL"), device_type_("AscendCL"),
device_id_(Context::Instance().GetDeviceID()), device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr), context_(nullptr),
acl_env_(nullptr) {} acl_env_(nullptr) {}
AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); } AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return ret;
} }
return model_process_.PredictFromHost(inputs, outputs); return model_process_.PredictFromHost(inputs, outputs);
} }
Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclGraphImpl::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return {};
} }
return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes); return model_process_.GetInputs();
} }
Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclGraphImpl::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return {};
} }
return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes); return model_process_.GetOutputs();
} }
Status AclGraphImpl::LoadAclModel(Buffer om_data) { Status AclGraphImpl::LoadAclModel(Buffer om_data) {
@ -72,44 +69,44 @@ Status AclGraphImpl::LoadAclModel(Buffer om_data) {
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
if (acl_ret != ACL_ERROR_NONE) { if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
return FAILED; return kMCDeviceError;
} }
// acl init model resource // acl init model resource
model_process_.set_model_id(acl_model_id); model_process_.set_model_id(acl_model_id);
Status ret = model_process_.PreInitModelResource(); Status ret = model_process_.PreInitModelResource();
if (ret != SUCCESS) { if (ret != kSuccess) {
(void)aclmdlUnload(acl_model_id); (void)aclmdlUnload(acl_model_id);
MS_LOG(ERROR) << "Pre init model resource failed."; MS_LOG(ERROR) << "Pre init model resource failed.";
return FAILED; return ret;
} }
MS_LOG(INFO) << "Load acl model success."; MS_LOG(INFO) << "Load acl model success.";
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::InitEnv() { Status AclGraphImpl::InitEnv() {
if (init_flag_) { if (init_flag_) {
return SUCCESS; return kSuccess;
} }
acl_env_ = AclEnvGuard::GetAclEnv(""); acl_env_ = AclEnvGuard::GetAclEnv("");
if (acl_env_ == nullptr) { if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed."; MS_LOG(ERROR) << "Acl init failed.";
return FAILED; return kMCDeviceError;
} }
aclError ret = aclrtSetDevice(device_id_); aclError ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
return FAILED; return kMCDeviceError;
} }
MS_LOG(INFO) << "Open device " << device_id_ << " success"; MS_LOG(INFO) << "Open device " << device_id_ << " success";
ret = aclrtCreateContext(&context_, device_id_); ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl create context failed"; MS_LOG(ERROR) << "Acl create context failed";
return FAILED; return kMCDeviceError;
} }
MS_LOG(INFO) << "Create context success"; MS_LOG(INFO) << "Create context success";
@ -117,7 +114,7 @@ Status AclGraphImpl::InitEnv() {
ret = aclrtGetRunMode(&run_mode); ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl get run mode failed"; MS_LOG(ERROR) << "Acl get run mode failed";
return FAILED; return kMCDeviceError;
} }
bool is_device = (run_mode == ACL_DEVICE); bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device); model_process_.SetIsDevice(is_device);
@ -125,24 +122,24 @@ Status AclGraphImpl::InitEnv() {
MS_LOG(INFO) << "Init acl success, device id " << device_id_; MS_LOG(INFO) << "Init acl success, device id " << device_id_;
init_flag_ = true; init_flag_ = true;
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::FinalizeEnv() { Status AclGraphImpl::FinalizeEnv() {
if (!init_flag_) { if (!init_flag_) {
return SUCCESS; return kSuccess;
} }
aclError rt_ret = aclrtSetCurrentContext(context_); aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) { if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed"; MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED; return kMCDeviceError;
} }
Status ret = model_process_.UnLoad(); Status ret = model_process_.UnLoad();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Unload model inner failed."; MS_LOG(ERROR) << "Unload model inner failed.";
return FAILED; return ret;
} }
if (context_ != nullptr) { if (context_ != nullptr) {
@ -161,16 +158,16 @@ Status AclGraphImpl::FinalizeEnv() {
MS_LOG(INFO) << "End to reset device " << device_id_; MS_LOG(INFO) << "End to reset device " << device_id_;
init_flag_ = false; init_flag_ = false;
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::Load() { Status AclGraphImpl::Load() {
// check graph type // check graph type
if (graph_->ModelType() != ModelType::kOM) { if (graph_->ModelType() != ModelType::kOM) {
Status ret = ConvertToOM(); Status ret = ConvertToOM();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load Failed."; MS_LOG(ERROR) << "Load Failed.";
return FAILED; return ret;
} }
} }
@ -180,15 +177,15 @@ Status AclGraphImpl::Load() {
// init // init
Status ret = InitEnv(); Status ret = InitEnv();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed."; MS_LOG(ERROR) << "InitEnv failed.";
return FAILED; return ret;
} }
// load model // load model
if (!load_flag_) { if (!load_flag_) {
ret = LoadAclModel(om_data); ret = LoadAclModel(om_data);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load acl model failed."; MS_LOG(ERROR) << "Load acl model failed.";
return ret; return ret;
} }
@ -198,24 +195,24 @@ Status AclGraphImpl::Load() {
aclError rt_ret = aclrtSetCurrentContext(context_); aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) { if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed"; MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED; return kMCDeviceError;
} }
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::ConvertToOM() { Status AclGraphImpl::ConvertToOM() {
MS_LOG(INFO) << "Start convert to om model."; MS_LOG(INFO) << "Start convert to om model.";
if (graph_ == nullptr) { if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid graph_ is null."; MS_LOG(ERROR) << "Invalid graph_ is null.";
return FAILED; return kMCFailed;
} }
auto &graph_data = GraphImpl::MutableGraphData(); auto &graph_data = GraphImpl::MutableGraphData();
MS_EXCEPTION_IF_NULL(graph_data); MS_EXCEPTION_IF_NULL(graph_data);
if (graph_->ModelType() == ModelType::kOM) { if (graph_->ModelType() == ModelType::kOM) {
MS_LOG(INFO) << "This model has been built, skip."; MS_LOG(INFO) << "This model has been built, skip.";
return SUCCESS; return kSuccess;
} else if (graph_->ModelType() == ModelType::kMindIR) { } else if (graph_->ModelType() == ModelType::kMindIR) {
auto func_graph = graph_data->GetFuncGraph(); auto func_graph = graph_data->GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -223,13 +220,13 @@ Status AclGraphImpl::ConvertToOM() {
Buffer om_data = model_converter.LoadMindIR(func_graph); Buffer om_data = model_converter.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) { if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Convert MindIR to OM failed."; MS_LOG(ERROR) << "Convert MindIR to OM failed.";
return FAILED; return kMCFailed;
} }
graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM); graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM);
MS_LOG(INFO) << "Convert MindIR to OM success."; MS_LOG(INFO) << "Convert MindIR to OM success.";
return SUCCESS; return kSuccess;
} }
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return FAILED; return kMCFailed;
} }
} // namespace mindspore::api } // namespace mindspore
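
The hunks above replace the old generic SUCCESS/FAILED pair with named status values (kSuccess, kMCFailed, kMCDeviceError) and, where a callee already produced a concrete Status, propagate it instead of collapsing it to FAILED. A minimal sketch of that convention, using an illustrative stand-in enum rather than the real Status class from include/api/status.h:

#include <iostream>

// Stand-in for the real status type; the actual Status is a class defined
// in include/api/status.h, not this hypothetical enum.
enum class Status { kSuccess, kMCFailed, kMCDeviceError };

Status ConvertToOM() { return Status::kMCFailed; }  // hypothetical failing step

Status Load() {
  Status ret = ConvertToOM();
  if (ret != Status::kSuccess) {
    std::cerr << "Load failed." << std::endl;
    return ret;  // propagate the concrete error, not a generic FAILED
  }
  return Status::kSuccess;
}

int main() { return Load() == Status::kSuccess ? 0 : 1; }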

@@ -27,18 +27,16 @@
 #include "cxx_api/graph/graph_impl.h"
 #include "cxx_api/factory.h"

-namespace mindspore::api {
+namespace mindspore {
 class AclGraphImpl : public GraphCell::GraphImpl {
  public:
   AclGraphImpl();
   ~AclGraphImpl() override;

-  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
+  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
   Status Load() override;
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+  std::vector<MSTensor> GetInputs() override;
+  std::vector<MSTensor> GetOutputs() override;

  private:
   Status ConvertToOM();
@@ -56,5 +54,5 @@ class AclGraphImpl : public GraphCell::GraphImpl {
   ModelProcess model_process_;
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
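
The four parallel out-parameters of GetInputsInfo/GetOutputsInfo (names, shapes, data types, memory sizes) are folded into a single vector of self-describing MSTensor objects. A rough sketch of the difference from the caller's side, with a minimal stand-in struct whose field names are illustrative, not the real MSTensor API:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Minimal stand-in for mindspore::MSTensor; the real class exposes name,
// shape, data type and size behind accessor methods.
struct MSTensor {
  std::string name;
  std::vector<int64_t> shape;
  size_t data_size;
};

std::vector<MSTensor> GetInputs() {
  // One entry per model input; a real implementation queries the loaded graph.
  return {{"input_0", {1, 3, 224, 224}, 1 * 3 * 224 * 224 * sizeof(float)}};
}

int main() {
  // Everything about an input travels together instead of across four
  // parallel vectors that could fall out of sync.
  for (const auto &t : GetInputs()) {
    std::cout << t.name << " needs " << t.data_size << " bytes" << std::endl;
  }
  return 0;
}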

File diff suppressed because it is too large

@@ -25,7 +25,7 @@
 #include "include/api/status.h"
 #include "include/api/types.h"

-namespace mindspore::api {
+namespace mindspore {
 struct AclTensorInfo {
   void *device_data;
   size_t buffer_size;
@@ -45,14 +45,12 @@ class ModelProcess {
         input_infos_(),
         output_infos_() {}
   ~ModelProcess() {}

   Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
   Status UnLoad();
-  Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+  Status PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
   Status PreInitModelResource();
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+  std::vector<MSTensor> GetInputs();
+  std::vector<MSTensor> GetOutputs();

   // override this method to avoid request/reply data copy
   void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
@@ -62,8 +60,9 @@ class ModelProcess {
  private:
   Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
-  Status CheckAndInitInput(const std::vector<Buffer> &inputs);
-  Status BuildOutputs(std::vector<Buffer> *outputs);
+  Status CheckAndInitInput(const std::vector<MSTensor> &inputs);
+  Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list);
+  Status BuildOutputs(std::vector<MSTensor> *outputs);
   Status InitInputsBuffer();
   Status InitOutputsBuffer();
@@ -80,7 +79,9 @@ class ModelProcess {
   aclmdlDataset *outputs_;
   std::vector<AclTensorInfo> input_infos_;
   std::vector<AclTensorInfo> output_infos_;
+  std::vector<MSTensor> input_tensors_;
+  std::vector<MSTensor> output_tensors_;
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
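
The new ConstructTensors helper together with the input_tensors_/output_tensors_ members suggests that MSTensor views are built once from the AclTensorInfo descriptions and cached, so GetInputs()/GetOutputs() can hand back the vectors cheaply on repeated calls. A sketch of that idea under assumed field names (the real AclTensorInfo also tracks device buffers and data types):

#include <iostream>
#include <string>
#include <vector>

// Assumed shapes of the two records; only the fields used below are kept,
// and even their names are guesses for illustration.
struct AclTensorInfo {
  std::string name;
  size_t buffer_size;
};

struct MSTensor {
  std::string name;
  std::vector<char> host_buffer;  // hypothetical host-side storage
};

// Build one MSTensor per ACL description; intended to run once and be
// cached in input_tensors_/output_tensors_ rather than rebuilt per query.
void ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list,
                      std::vector<MSTensor> *tensor_list) {
  tensor_list->clear();
  tensor_list->reserve(acl_tensor_list.size());
  for (const auto &info : acl_tensor_list) {
    MSTensor t;
    t.name = info.name;
    t.host_buffer.resize(info.buffer_size);  // device-to-host copy goes here
    tensor_list->push_back(std::move(t));
  }
}

int main() {
  std::vector<AclTensorInfo> infos = {{"output_0", 4096}};
  std::vector<MSTensor> cached;
  ConstructTensors(infos, &cached);
  std::cout << cached[0].name << ": " << cached[0].host_buffer.size()
            << " bytes" << std::endl;
  return 0;
}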

File diff suppressed because it is too large

@@ -28,40 +28,56 @@
 #include "ir/anf.h"
 #include "cxx_api/model/model_impl.h"
 #include "runtime/context.h"
+#include "cxx_api/graph/graph_utils.h"

-namespace mindspore::api {
+namespace mindspore {
 class AscendGraphImpl : public GraphCell::GraphImpl {
  public:
   AscendGraphImpl();
   ~AscendGraphImpl() override;

-  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
+  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
   Status Load() override;
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+  std::vector<MSTensor> GetInputs() override;
+  std::vector<MSTensor> GetOutputs() override;

  private:
+  class MsEnvGuard;
+
   Status InitEnv();
-  Status FinalizeEnv();
   Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
   Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
   std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
-  Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+  Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

   std::shared_ptr<session::SessionBasic> session_impl_;
   uint32_t graph_id_;
   std::string device_type_;
   uint32_t device_id_;
   rtContext_t context_;
-  std::vector<tensor::TensorPtr> inputs_;
-  std::vector<tensor::TensorPtr> outputs_;
+  std::vector<tensor::TensorPtr> inputs_info_;
+  std::vector<tensor::TensorPtr> outputs_info_;
+  std::vector<tensor::TensorPtr> last_inputs_;
+  std::vector<tensor::TensorPtr> last_outputs_;
   std::vector<std::string> input_names_;
   std::vector<std::string> output_names_;
-  bool init_flag_;
   bool load_flag_;
+  std::shared_ptr<MsEnvGuard> env_guard_;
+};
+
+class AscendGraphImpl::MsEnvGuard {
+ public:
+  explicit MsEnvGuard(uint32_t device_id);
+  ~MsEnvGuard();
+  Status GetErrno() const { return errno_; }
+  static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
+
+ private:
+  static std::weak_ptr<MsEnvGuard> global_ms_env_;
+  static std::mutex global_ms_env_mutex_;
+
+  Status errno_;
+  uint32_t device_id_;
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
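
The MsEnvGuard declaration (a static weak_ptr plus a mutex behind a GetEnv factory) points at a reference-counted, process-wide environment: the first caller initializes the ME environment, later callers share the same live guard, and the destructor of the last shared_ptr tears it down, which is why the separate FinalizeEnv() and init_flag_ disappear. A compilable sketch of that pattern; the constructor and destructor bodies stand in for the real init/finalize calls:

#include <cstdint>
#include <memory>
#include <mutex>

// Sketch of the guard pattern implied by the header above; the real class
// initializes and finalizes the Ascend/ME environment, mocked out here as
// empty constructor/destructor bodies.
class MsEnvGuard {
 public:
  explicit MsEnvGuard(uint32_t device_id) : device_id_(device_id) {
    // real code: initialize the environment for device_id_ exactly once
  }
  ~MsEnvGuard() {
    // real code: finalize the environment when the last user releases it
  }

  static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id) {
    std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
    auto env = global_ms_env_.lock();  // reuse the live guard if one exists
    if (env == nullptr) {
      env = std::make_shared<MsEnvGuard>(device_id);
      global_ms_env_ = env;  // observe without extending its lifetime
    }
    return env;
  }

 private:
  static std::weak_ptr<MsEnvGuard> global_ms_env_;
  static std::mutex global_ms_env_mutex_;
  uint32_t device_id_;
};

std::weak_ptr<MsEnvGuard> MsEnvGuard::global_ms_env_;
std::mutex MsEnvGuard::global_ms_env_mutex_;

int main() {
  auto a = MsEnvGuard::GetEnv(0);
  auto b = MsEnvGuard::GetEnv(0);  // same underlying guard as a
  return (a == b) ? 0 : 1;
}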

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff
