fix for code naming specification

pull/4276/head
hangq 5 years ago
parent e73e9a9aee
commit 7556389aef

@@ -22,41 +22,41 @@
#include "include/ms_tensor.h"
namespace mindspore::lite {
/// \brief Allocator defined by MindSpore Lite.
/// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
///
/// \note List public class and interface for reference.
class Allocator;
/// \brief CpuBindMode defined by MindSpore Lite.
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
enum CpuBindMode {
MID_CPU = -1, /**< bind mid cpu first */
MID_CPU = -1, /**< bind middle cpu first */
HIGHER_CPU = 1, /**< bind higher cpu first */
NO_BIND = 0 /**< no bind */
};
/// \brief DeviceType defined by MindSpore Lite.
/// \brief DeviceType defined for holding user's preferred backend.
typedef enum {
DT_CPU, /**< CPU device type */
DT_GPU, /**< GPU device type */
DT_NPU /**< NPU device type */
} DeviceType;
/// \brief DeviceContext defined by MindSpore Lite.
/// \brief DeviceContext defined for holding DeviceType.
typedef struct {
DeviceType type; /**< device type */
} DeviceContext;
/// \brief Context defined by MindSpore Lite
/// \brief Context defined for holding some environment for runtime.
class MS_API Context {
public:
/// \brief Constructor of MindSpore Lite context using default value for parameters.
/// \brief Constructor of MindSpore Lite Context using default value for parameters.
///
/// \return Instance of MindSpore Lite Context.
Context();
/// \brief Constructor of MindSpore Lite Context using input value for parameters.
///
/// \param[in] thread_num Define the threadNum during the runtime.
/// \param[in] thread_num Define the work thread number during the runtime.
/// \param[in] allocator Define the allocator for malloc.
/// \param[in] device_ctx Define device information during the runtime.
Context(int thread_num, std::shared_ptr<Allocator> allocator, DeviceContext device_ctx);
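
For reference, a minimal usage sketch of the Context API above (header path assumed from the include/ layout seen in this diff; that a null allocator falls back to a default memory pool is an assumption, not documented here):

#include <memory>
#include "include/context.h"

void BuildContext() {
  // DeviceContext only carries the preferred backend.
  mindspore::lite::DeviceContext device_ctx{mindspore::lite::DT_CPU};
  // Two work threads; allocator left null here (assumed to mean "use the default pool").
  mindspore::lite::Context ctx(2, nullptr, device_ctx);
  // cpu_bind_mode_ is read directly by LiteSession::BindThread further down in this diff.
  ctx.cpu_bind_mode_ = mindspore::lite::MID_CPU;
}
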

@@ -19,6 +19,7 @@
namespace mindspore {
namespace lite {
/// \brief STATUS defined error code in MindSpore Lite.
using STATUS = int;
/* Success */
@@ -33,8 +34,8 @@ constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */
constexpr int RET_MEMORY_FAILED = -6; /**< Create memory failed. */
/* Executor error code, range: [-101,-200] */
constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to checking range. */
constexpr int RET_INPUT_TENSOR_ERROR = -102; /**< Failed to checking input tensor. */
constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */
constexpr int RET_INPUT_TENSOR_ERROR = -102; /**< Failed to check input tensor. */
constexpr int RET_REENTRANT_ERROR = -103; /**< An executor is already running. */
/* Graph error code, range: [-201,-300] */
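
A short sketch of checking these codes at a call site (RET_OK appears elsewhere in this diff; the helper below is hypothetical):

#include "include/errorcode.h"

bool Succeeded(mindspore::lite::STATUS ret) {
  // STATUS is a plain int; negative values fall into the documented ranges,
  // e.g. [-101,-200] for executor errors such as RET_INPUT_TENSOR_ERROR (-102).
  return ret == mindspore::lite::RET_OK;
}
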

@@ -27,15 +27,17 @@
namespace mindspore {
namespace session {
/// \brief CallBackParam defined input arguments for callBack function.
struct CallBackParam {
std::string name_callback_param;
std::string type_callback_param;
std::string name_callback_param; /**< node name argument */
std::string type_callback_param; /**< node type argument */
};
/// \brief KernelCallBack defined the function pointer for callBack.
using KernelCallBack = std::function<bool(std::vector<tensor::MSTensor *> inputs,
std::vector<tensor::MSTensor *> outputs, const CallBackParam &opInfo)>;
/// \brief LiteSession defined by MindSpore Lite.
/// \brief LiteSession defined session in MindSpore Lite for compiling Model and forwarding model.
class MS_API LiteSession {
public:
/// \brief Static method to create a LiteSession pointer.
@@ -48,52 +50,52 @@ class MS_API LiteSession {
/// \brief Destructor of MindSpore Lite LiteSession.
virtual ~LiteSession() = default;
/// \brief Try to bind or unbind threads in the thread pool to specified cpu core.
/// \brief Try to bind or unbind threads in the thread pool to the specified cpu core.
///
/// \param[in] if_bind Define weather to bind or unbind threads.
/// \param[in] if_bind Define whether to bind or unbind threads.
virtual void BindThread(bool if_bind) = 0;
/// \brief Compile MindSpore lite model.
/// \brief Compile MindSpore Lite model.
///
/// \note CompileGraph should be called before RunGraph.
///
/// \param[in] model Define the model to be compiled.
///
/// \return ErrorCode of compile graph.
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h.
virtual int CompileGraph(lite::Model *model) = 0;
/// \brief Get input MindSpore Lite MSTensors of model.
///
/// \return A vector of MindSpore Lite MSTensor.
/// \return The vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetInputs() const = 0;
/// \brief Get input MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \return A vector of MindSpore Lite MSTensor.
/// \return The vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetInputsByName(const std::string &node_name) const = 0;
/// \brief Run session with callback.
///
/// \param[in] before Define a call_back_function called before running each node
/// \param[in] after Define a call_back_function called after running each node
/// \param[in] before Define a call_back_function called before running each node.
/// \param[in] after Define a call_back_function called after running each node.
///
/// \note RunGraph should be called after CompileGraph.
///
/// \return ErrorCode of run graph.
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
/// \brief Get output MindSpore Lite MSTensors of model.
///
/// \return A map of output node name and MindSpore Lite MSTensor.
/// \return The map of output node name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \return A vector of MindSpore Lite MSTensor.
/// \return The vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const = 0;
};
} // namespace session
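
Taken together, a hedged end-to-end sketch of this interface (the factory name CreateSession is assumed from the "static method to create a LiteSession pointer" noted above; header paths follow this tree's include/ layout; error handling is schematic):

#include <iostream>
#include "include/context.h"
#include "include/lite_session.h"
#include "include/model.h"

int Infer(const char *model_buf, size_t size) {
  auto *model = mindspore::lite::Model::Import(model_buf, size);
  mindspore::lite::Context context;
  auto *session = mindspore::session::LiteSession::CreateSession(&context);
  if (session->CompileGraph(model) != mindspore::lite::RET_OK) {  // CompileGraph before RunGraph
    return -1;
  }
  for (auto *tensor : session->GetInputs()) {
    // fill tensor->MutableData() with tensor->Size() bytes of input
  }
  // Optional callback matching KernelCallBack; per the Executor::Run hunk further
  // down, returning false is only logged as an error, it does not abort the run.
  mindspore::session::KernelCallBack before =
      [](std::vector<mindspore::tensor::MSTensor *> inputs,
         std::vector<mindspore::tensor::MSTensor *> outputs,
         const mindspore::session::CallBackParam &info) {
        std::cout << info.name_callback_param << " (" << info.type_callback_param << ")\n";
        return true;
      };
  int ret = session->RunGraph(before);
  auto outputs = session->GetOutputs();  // output node name -> MSTensors
  return ret;
}
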

@@ -25,24 +25,24 @@
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
/// \brief ModelImpl defined by MindSpore Lite.
/// \brief ModelImpl defined the implement class of Model in MindSpore Lite.
///
/// \note List public class and interface for reference.
class ModelImpl;
namespace lite {
/// \brief Primitive defined by MindSpore Lite.
/// \brief Primitive defined as prototype of operator.
///
/// \note List public class and interface for reference.
class Primitive;
/// \brief Model defined by MindSpore Lite.
/// \brief Model defined model in MindSpore Lite for managing graph.
class MS_API Model {
public:
/// \brief Static method to create a Model pointer.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes numbers of model buffer.
/// \param[in] size Define bytes number of model buffer.
///
/// \return Pointer of MindSpore Lite Model.
static Model *Import(const char *model_buf, size_t size);
@@ -59,17 +59,17 @@ class MS_API Model {
///
/// \param[in] name Define name of primitive to be returned.
///
/// \return A pointer of MindSpore Lite Primitive.
/// \return the pointer of MindSpore Lite Primitive.
lite::Primitive *GetOp(const std::string &name) const;
/// \brief Get MindSpore Lite MetaGraph.
/// \brief Get graph defined in flatbuffers.
///
/// \return A pointer of MindSpore Lite MetaGraph.
/// \return the pointer of graph defined in flatbuffers.
const schema::MetaGraph *GetMetaGraph() const;
/// \brief Get MindSpore Lite ModelImpl.
///
/// \return A pointer of MindSpore Lite ModelImpl.
/// \return the pointer of MindSpore Lite ModelImpl.
ModelImpl *model_impl();
/// \brief Free MetaGraph in MindSpore Lite Model.
@@ -84,7 +84,7 @@ class MS_API ModelBuilder {
public:
/// \brief OutEdge defined by MindSpore Lite.
struct OutEdge {
std::string nodeId; /**< Id of a node linked by this edge */
std::string nodeId; /**< ID of a node linked by this edge */
size_t outEdgeIndex; /**< Index of this edge */
};
@@ -101,12 +101,12 @@ class MS_API ModelBuilder {
/// \param[in] op Define the primitive to be added.
/// \param[in] inputs Define input edge of primitive to be added.
///
/// \return Id of the primitive added.
/// \return ID of the primitive added.
virtual std::string AddOp(const lite::Primitive &op, const std::vector<OutEdge> &inputs) = 0;
/// \brief Finish constructing the model.
///
/// \return A pointer of MindSpore Lite Model.
/// \return the pointer of MindSpore Lite Model.
virtual Model *Construct();
};
} // namespace lite
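
A file-loading sketch to feed Model::Import (schematic; only the Import(const char *, size_t) signature comes from the diff above, the rest is standard file IO):

#include <fstream>
#include <string>
#include <vector>
#include "include/model.h"

mindspore::lite::Model *LoadModel(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);
  if (!ifs) {
    return nullptr;
  }
  auto size = static_cast<size_t>(ifs.tellg());
  std::vector<char> buf(size);
  ifs.seekg(0);
  ifs.read(buf.data(), static_cast<std::streamsize>(size));
  // size is the byte number of the model buffer, as documented above
  return mindspore::lite::Model::Import(buf.data(), size);
}
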

@@ -25,7 +25,7 @@
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
namespace tensor {
/// \brief MSTensor defined by MindSpore Lite.
/// \brief MSTensor defined tensor in MindSpore Lite.
class MS_API MSTensor {
public:
/// \brief Constructor of MindSpore Lite MSTensor.
@@ -41,7 +41,7 @@ class MS_API MSTensor {
/// \note TypeId is defined in mindspore/mindspore/core/ir/dtype/type_id.h. Only number types in TypeId enum are
/// suitable for MSTensor.
///
/// \return A pointer of MSTensor.
/// \return the pointer of MSTensor.
static MSTensor *CreateTensor(TypeId data_type, const std::vector<int> &shape);
/// \brief Destructor of MindSpore Lite MSTensor.
@@ -69,7 +69,7 @@ class MS_API MSTensor {
/// \brief Set shape for the MindSpore Lite MSTensor.
///
/// \param[in] shape Define A vector of int as shape to be set into the MindSpore Lite MSTensor.
/// \param[in] shape Define a vector of int as shape to be set into the MindSpore Lite MSTensor.
///
/// \return size of shape of the MindSpore Lite MSTensor after set.
virtual size_t set_shape(const std::vector<int> &shape) = 0;
@@ -96,15 +96,13 @@ class MS_API MSTensor {
/// \return Byte size of data in MSTensor.
virtual size_t Size() const = 0;
/// \brief Get pointer of data in MSTensor.
/// \brief Get the pointer of data in MSTensor.
///
/// \note The data pointer can be used to both write and read data in MSTensor.
///
/// \return A pointer points to data in MSTensor.
/// \return the pointer points to data in MSTensor.
virtual void *MutableData() const = 0;
};
using MultiTensor = std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>>;
} // namespace tensor
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
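
A small sketch of the tensor factory above (kNumberTypeFloat32 is taken from the TypeId enum in type_id.h referenced in the note; that the caller owns and deletes a directly created tensor is an assumption):

#include "include/ms_tensor.h"

void FillTensor() {
  auto *tensor = mindspore::tensor::MSTensor::CreateTensor(mindspore::kNumberTypeFloat32, {1, 224, 224, 3});
  // MutableData can be used for both writing and reading tensor contents
  auto *data = static_cast<float *>(tensor->MutableData());
  for (size_t i = 0; i < tensor->Size() / sizeof(float); ++i) {
    data[i] = 0.0f;
  }
  delete tensor;  // assumed ownership, see lead-in
}
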

@@ -20,11 +20,11 @@
#include "src/common/ms_tensor_utils.h"
namespace mindspore::lite {
int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Tensor *> &outputs,
int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor::Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
const session::KernelCallBack &before, const session::KernelCallBack &after) {
MS_ASSERT(nullptr != allocator);
for (auto &inTensor : inputs) {
for (auto &inTensor : in_tensors) {
if (inTensor == nullptr) {
MS_LOG(ERROR) << "Graph input tensor is nullptr";
return RET_ERROR;
@@ -39,31 +39,31 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
MS_ASSERT(nullptr != kernel);
if (before != nullptr) {
if (!before(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()),
{kernel->Name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name();
if (!before(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()),
{kernel->name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name();
}
}
auto ret = kernel->Run();
if (0 != ret) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->Name();
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
if (after != nullptr) {
if (!after(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()),
{kernel->Name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name();
if (!after(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()),
{kernel->name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name();
}
}
for (auto input_kernel : kernel->GetInKernels()) {
for (auto input_kernel : kernel->in_kernels()) {
MS_ASSERT(input_kernel != nullptr);
if (input_kernel->is_model_output()) {
continue;
}
ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed";
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
}

@@ -29,7 +29,7 @@ class Executor {
int Prepare(std::vector<kernel::LiteKernel *> &kernels) { return 0; }
int Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Tensor *> &outputs,
int Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor::Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr);
@@ -39,9 +39,6 @@ class Executor {
int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr);
int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr);
protected:
Context *context = nullptr;
};
} // namespace mindspore::lite

@@ -33,8 +33,8 @@ KernelFactory *KernelFactory::GetInstance() {
return &instance;
}
LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &inputs,
const std::vector<tensor::Tensor *> &outputs, const lite::Primitive *primitive,
LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
@@ -45,7 +45,7 @@ LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &inputs
}
auto creator = KernelRegistry::GetInstance()->GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(inputs, outputs, parameter, ctx, key, primitive);
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive);
return kernel;
}
return nullptr;

@@ -31,8 +31,8 @@ class KernelFactory {
virtual ~KernelFactory();
static KernelFactory *GetInstance();
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &inputs,
const std::vector<tensor::Tensor *> &outputs, const lite::Primitive *primitive,
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key);
};
} // namespace mindspore::lite

@@ -20,13 +20,13 @@
namespace mindspore::kernel {
void LiteKernel::InitOutTensorRefCount() {
for (auto *tensor : this->outputs_) {
tensor->SetRefCount(this->out_kernel_.size());
for (auto *tensor : this->out_tensors_) {
tensor->SetRefCount(this->out_kernels_.size());
}
}
int LiteKernel::DecOutTensorRefCount() {
for (auto *tensor : this->outputs_) {
for (auto *tensor : this->out_tensors_) {
tensor->decRefCount();
if (0 >= tensor->RefCount()) {
auto ret = tensor->FreeData();
@@ -43,7 +43,7 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> input_kernels;
for (const auto kernel : kernels) {
for (auto input : kernel->GetInKernels()) {
for (auto input : kernel->in_kernels()) {
auto iter = std::find(kernels.begin(), kernels.end(), input);
if (iter == kernels.end()) {
input_kernels.emplace_back(input);
@@ -57,7 +57,7 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> output_kernels;
for (const auto kernel : kernels) {
for (const auto output : kernel->GetOutKernels()) {
for (const auto output : kernel->out_kernels()) {
auto iter = std::find(kernels.begin(), kernels.end(), output);
if (iter == kernels.end()) {
output_kernels.emplace_back(output);
@@ -72,11 +72,11 @@ std::vector<lite::tensor::Tensor *> LiteKernelUtil::SubgraphInputTensors(
std::vector<lite::tensor::Tensor *> input_tensors;
std::vector<lite::tensor::Tensor *> all_output_tensors;
for (const auto &kernel : kernels) {
all_output_tensors.insert(all_output_tensors.end(), kernel->GetOutputs().begin(), kernel->GetOutputs().end());
all_output_tensors.insert(all_output_tensors.end(), kernel->out_tensors().begin(), kernel->out_tensors().end());
}
std::vector<kernel::LiteKernel *> input_kernels = SubgraphInputKernels(kernels);
for (const auto &kernel : input_kernels) {
for (const auto &tensor : kernel->GetInputs()) {
for (const auto &tensor : kernel->in_tensors()) {
auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor);
if (iter == all_output_tensors.end() && tensor->Data() == nullptr) {
input_tensors.emplace_back(tensor);
@@ -91,11 +91,11 @@ std::vector<lite::tensor::Tensor *> LiteKernelUtil::SubgraphOutputTensors(
std::vector<lite::tensor::Tensor *> output_tensors;
std::vector<lite::tensor::Tensor *> all_input_tensors;
for (const auto &kernel : kernels) {
all_input_tensors.insert(all_input_tensors.end(), kernel->GetInputs().begin(), kernel->GetInputs().end());
all_input_tensors.insert(all_input_tensors.end(), kernel->in_tensors().begin(), kernel->in_tensors().end());
}
std::vector<kernel::LiteKernel *> output_kernels = SubgraphOutputKernels(kernels);
for (const auto &kernel : output_kernels) {
for (const auto &tensor : kernel->GetOutputs()) {
for (const auto &tensor : kernel->out_tensors()) {
auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor);
if (iter == all_input_tensors.end()) {
output_tensors.emplace_back(tensor);
@@ -111,13 +111,13 @@ void LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> &k
if (search_kernel == kernel) {
continue;
}
for (auto *tensor : kernel->GetInputs()) {
if (lite::IsContain(search_kernel->GetOutputs(), tensor)) {
for (auto *tensor : kernel->in_tensors()) {
if (lite::IsContain(search_kernel->out_tensors(), tensor)) {
kernel->AddInKernel(search_kernel);
}
}
for (auto *tensor : kernel->GetOutputs()) {
if (lite::IsContain(search_kernel->GetInputs(), tensor)) {
for (auto *tensor : kernel->out_tensors()) {
if (lite::IsContain(search_kernel->in_tensors(), tensor)) {
kernel->AddOutKernel(search_kernel);
}
}

@@ -57,102 +57,109 @@ struct KernelKey {
class LiteKernel {
public:
LiteKernel() = default;
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &in_tensors,
const std::vector<lite::tensor::Tensor *> &out_tensors, const lite::Context *ctx,
const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) {
if (opParameter && ctx) {
opParameter->thread_num_ = ctx->thread_num_;
: op_parameter_(parameter),
in_tensors_(in_tensors),
out_tensors_(out_tensors),
primitive_(primitive),
context_(ctx) {
if (op_parameter_ && ctx) {
op_parameter_->thread_num_ = ctx->thread_num_;
}
this->in_kernel_.clear();
this->out_kernel_.clear();
this->in_kernels_.clear();
this->out_kernels_.clear();
}
virtual ~LiteKernel() { delete opParameter; }
virtual ~LiteKernel() { delete op_parameter_; }
virtual int Prepare() {
if (!InferShapeDone()) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
if (need_reinit) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (need_reinit_) {
Init();
}
}
auto &outputs = this->GetOutputs();
auto &outputs = this->out_tensors();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
output->MallocData();
}
return RET_OK;
}
virtual int Init() { return -1; }
virtual int ReSize() { return -1; }
virtual int Run() { return -1; }
std::string Name() { return this->name; }
virtual void train() { train_mode = true; }
virtual bool is_train() { return train_mode == true; }
virtual void eval() { train_mode = false; }
virtual bool is_eval() { return train_mode == false; }
void set_name(const std::string &name) { this->name = name; }
std::string name() { return this->name_; }
void set_is_model_output(bool is_model_output) { this->is_model_output_ = is_model_output; }
virtual void train() { train_mode_ = true; }
virtual bool is_train() { return train_mode_; }
virtual void eval() { train_mode_ = false; }
virtual bool is_eval() { return !train_mode_; }
bool is_model_output() { return this->is_model_output_; }
void set_name(const std::string &name) { this->name_ = name; }
schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; }
void set_is_model_output(bool is_model_output) { this->is_model_output_ = is_model_output; }
bool is_model_output() const { return this->is_model_output_; }
std::string type_str() {
return this->opParameter ? schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_)
: "ERROR:undefined primitive!";
schema::PrimitiveType Type() {
return (this->op_parameter_ != nullptr) ? schema::PrimitiveType(this->op_parameter_->type_)
: schema::PrimitiveType_NONE;
}
void SetInputs(const std::vector<lite::tensor::Tensor *> &inputs) { this->inputs_ = inputs; }
std::string type_str() { return schema::EnumNamePrimitiveType(this->Type()); }
void SetOutputs(const std::vector<lite::tensor::Tensor *> &outputs) { this->outputs_ = outputs; }
void set_in_tensors(const std::vector<lite::tensor::Tensor *> &in_tensors) { this->in_tensors_ = in_tensors; }
std::vector<lite::tensor::Tensor *> &GetInputs() { return this->inputs_; }
void set_out_tensors(const std::vector<lite::tensor::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; }
std::vector<lite::tensor::Tensor *> &GetOutputs() { return this->outputs_; }
std::vector<lite::tensor::Tensor *> &in_tensors() { return this->in_tensors_; }
void AddInKernel(LiteKernel *kernel) { this->in_kernel_.emplace_back(kernel); }
std::vector<lite::tensor::Tensor *> &out_tensors() { return this->out_tensors_; }
void AddOutKernel(LiteKernel *kernel) { this->out_kernel_.emplace_back(kernel); }
void AddInKernel(LiteKernel *kernel) { this->in_kernels_.emplace_back(kernel); }
std::vector<LiteKernel *> &GetInKernels() { return this->in_kernel_; }
void AddOutKernel(LiteKernel *kernel) { this->out_kernels_.emplace_back(kernel); }
std::vector<LiteKernel *> &GetOutKernels() { return this->out_kernel_; }
std::vector<LiteKernel *> &in_kernels() { return this->in_kernels_; }
std::vector<LiteKernel *> &out_kernels() { return this->out_kernels_; }
void InitOutTensorRefCount();
int DecOutTensorRefCount();
const KernelKey Desc() const { return desc; }
KernelKey desc() const { return desc_; }
void set_desc(const KernelKey kernel_key) { desc = kernel_key; }
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }
void SetNeedReInit() { need_reinit = true; }
void set_need_reinit() { need_reinit_ = true; }
protected:
bool InferShapeDone() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
return false;
}
return true;
}
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()) && true; }
KernelKey desc;
std::string name;
OpParameter *opParameter = nullptr;
KernelKey desc_;
std::string name_;
OpParameter *op_parameter_ = nullptr;
const lite::Primitive *primitive_ = nullptr;
const lite::Context *context_ = nullptr;
// tensor will free in ~lite_session()
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
std::vector<LiteKernel *> in_kernel_;
std::vector<LiteKernel *> out_kernel_;
bool train_mode = false;
bool need_reinit = false;
std::vector<lite::tensor::Tensor *> in_tensors_;
std::vector<lite::tensor::Tensor *> out_tensors_;
std::vector<LiteKernel *> in_kernels_;
std::vector<LiteKernel *> out_kernels_;
bool train_mode_ = false;
bool need_reinit_ = false;
bool is_model_output_ = false;
};
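
To make the renamed members concrete, a hypothetical pass-through kernel built on this class (IdentityCPUKernel is illustrative, not part of this commit; the header path and a Size() accessor mirroring MSTensor::Size() on the internal tensor are assumptions):

#include <cstring>
#include "src/lite_kernel.h"  // path assumed from this tree

namespace mindspore::kernel {
class IdentityCPUKernel : public LiteKernel {
 public:
  IdentityCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &in_tensors,
                    const std::vector<lite::tensor::Tensor *> &out_tensors, const lite::Context *ctx,
                    const lite::Primitive *primitive)
      : LiteKernel(parameter, in_tensors, out_tensors, ctx, primitive) {}
  int Init() override { return 0; }    // 0 == RET_OK
  int ReSize() override { return 0; }
  int Run() override {
    // op_parameter_, in_tensors_, out_tensors_ replace opParameter, inputs_, outputs_
    auto *in = in_tensors_.front();
    auto *out = out_tensors_.front();
    std::memcpy(out->Data(), in->Data(), in->Size());  // Size() assumed, see lead-in
    return 0;
  }
};
}  // namespace mindspore::kernel
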

@@ -74,46 +74,46 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
}
}
this->tensors.emplace_back(dstTensor);
this->tensors_.emplace_back(dstTensor);
}
return RET_OK;
}
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->inputs.empty());
MS_ASSERT(this->inputs_.empty());
MS_ASSERT(meta_graph != nullptr);
for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) {
auto in_tensor_idx = size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i));
MS_ASSERT(in_tensor_idx < this->tensors.size());
auto *in_tensor = this->tensors.at(in_tensor_idx);
MS_ASSERT(in_tensor_idx < this->tensors_.size());
auto *in_tensor = this->tensors_.at(in_tensor_idx);
MS_ASSERT(in_tensor != nullptr);
this->inputs.emplace_back(in_tensor);
this->inputs_.emplace_back(in_tensor);
}
}
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->outputs.empty());
MS_ASSERT(this->outputs_.empty());
MS_ASSERT(meta_graph != nullptr);
for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) {
auto out_tensor_idx = size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i));
MS_ASSERT(out_tensor_idx < this->tensors.size());
auto *out_tensor = this->tensors.at(out_tensor_idx);
MS_ASSERT(out_tensor_idx < this->tensors_.size());
auto *out_tensor = this->tensors_.at(out_tensor_idx);
MS_ASSERT(out_tensor != nullptr);
this->outputs.emplace_back(out_tensor);
this->outputs_.emplace_back(out_tensor);
}
}
void LiteSession::InitGraphInputMap(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->input_map.empty());
MS_ASSERT(this->input_map_.empty());
MS_ASSERT(meta_graph != nullptr);
auto graph_input_node_indexes = GetGraphInputNodes(meta_graph);
for (auto in_node_index : graph_input_node_indexes) {
auto *in_node = meta_graph->nodes()->GetAs<schema::CNode>(in_node_index);
MS_ASSERT(nullptr != in_node);
MS_ASSERT(this->input_map.find(in_node->name()->str()) == this->input_map.end());
MS_ASSERT(this->input_map_.find(in_node->name()->str()) == this->input_map_.end());
for (size_t i = 0; i < in_node->inputIndex()->size(); i++) {
auto in_tensor_index = size_t(in_node->inputIndex()->GetAs<uint32_t>(i));
bool is_graph_input = false;
@@ -126,25 +126,25 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
if (!is_graph_input) {
continue;
}
MS_ASSERT(in_tensor_index < this->tensors.size());
auto *in_tensor = this->tensors.at(in_tensor_index);
MS_ASSERT(in_tensor_index < this->tensors_.size());
auto *in_tensor = this->tensors_.at(in_tensor_index);
MS_ASSERT(in_tensor != nullptr);
auto *ms_tensor = new tensor::LiteTensor(in_tensor);
MS_ASSERT(nullptr != ms_tensor);
this->input_map[in_node->name()->str()].emplace_back(ms_tensor);
this->input_map_[in_node->name()->str()].emplace_back(ms_tensor);
}
}
}
void LiteSession::InitGraphOutputMap(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->output_map.empty());
MS_ASSERT(this->output_map_.empty());
MS_ASSERT(meta_graph != nullptr);
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
for (auto out_node_index : graph_output_node_indexes) {
auto *out_node = meta_graph->nodes()->GetAs<schema::CNode>(out_node_index);
MS_ASSERT(nullptr != out_node);
MS_ASSERT(this->output_map.find(out_node->name()->str()) == this->output_map.end());
MS_ASSERT(this->output_map_.find(out_node->name()->str()) == this->output_map_.end());
for (size_t i = 0; i < out_node->outputIndex()->size(); i++) {
auto out_tensor_index = size_t(out_node->outputIndex()->GetAs<uint32_t>(i));
bool is_graph_output = false;
@@ -157,12 +157,12 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) {
if (!is_graph_output) {
continue;
}
MS_ASSERT(out_tensor_index < this->tensors.size());
auto *out_tensor = this->tensors.at(out_tensor_index);
MS_ASSERT(out_tensor_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(out_tensor_index);
MS_ASSERT(out_tensor != nullptr);
auto *ms_tensor = new tensor::LiteTensor(out_tensor);
MS_ASSERT(nullptr != ms_tensor);
this->output_map[out_node->name()->str()].emplace_back(ms_tensor);
this->output_map_[out_node->name()->str()].emplace_back(ms_tensor);
}
}
}
@@ -191,7 +191,7 @@ int LiteSession::CompileGraph(Model *model) {
// scheduler kernels
Scheduler scheduler(context_);
ret = scheduler.Schedule(model, &tensors, &kernels);
ret = scheduler.Schedule(model, &tensors_, &kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
return ret;
@@ -202,7 +202,7 @@ int LiteSession::CompileGraph(Model *model) {
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
std::vector<mindspore::tensor::MSTensor *> ret;
for (auto &iter : this->input_map) {
for (auto &iter : this->input_map_) {
auto &node_input_tensors = iter.second;
for (auto tensor : node_input_tensors) {
if (!IsContain(ret, tensor)) {
@@ -219,14 +219,14 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
context_->running_ = true;
Executor executor;
if (before == nullptr && after == nullptr) {
return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get());
return executor.Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get());
} else {
return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get(), before, after);
return executor.Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get(), before, after);
}
}
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputs() const {
return this->output_map;
return this->output_map_;
}
int LiteSession::Init(Context *context) {
@@ -252,46 +252,46 @@ int LiteSession::Init(Context *context) {
return RET_OK;
}
void LiteSession::BindThread(bool ifBind) {
void LiteSession::BindThread(bool if_bind) {
if (this->context_->cpu_bind_mode_ != NO_BIND) {
DoAllThreadBind(ifBind, static_cast<int>(this->context_->cpu_bind_mode_));
DoAllThreadBind(if_bind, static_cast<int>(this->context_->cpu_bind_mode_));
}
}
LiteSession::~LiteSession() {
for (auto *tensor : tensors) {
for (auto *tensor : tensors_) {
// weight data can not be freed here; it will be freed when freeing meta_graph
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs, tensor)) {
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor)) {
tensor->SetData(nullptr);
}
delete tensor;
}
// tensor::Tensor pointers in input_map and output_map are freed in tensors
for (auto iter : this->input_map) {
for (auto iter : this->input_map_) {
for (auto *ms_tensor : iter.second) {
((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr);
delete ms_tensor;
}
iter.second.clear();
}
input_map.clear();
for (auto iter : this->output_map) {
input_map_.clear();
for (auto iter : this->output_map_) {
for (auto *ms_tensor : iter.second) {
((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr);
delete ms_tensor;
}
iter.second.clear();
}
output_map.clear();
for (auto *kernel : kernels) {
output_map_.clear();
for (auto *kernel : kernels_) {
delete kernel;
}
delete this->context_;
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(const std::string &name) const {
auto ret = input_map.find(name);
if (ret == input_map.end()) {
auto ret = input_map_.find(name);
if (ret == input_map_.end()) {
MS_LOG(WARNING) << "Node " << name << " is not an input node";
std::vector<mindspore::tensor::MSTensor *> empty_ret;
return empty_ret;
@@ -300,8 +300,8 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(const st
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputsByName(const std::string &name) const {
auto ret = output_map.find(name);
if (ret == output_map.end()) {
auto ret = output_map_.find(name);
if (ret == output_map_.end()) {
MS_LOG(WARNING) << "Node " << name << " is not an output node";
std::vector<mindspore::tensor::MSTensor *> empty_ret;
return empty_ret;

@@ -38,7 +38,7 @@ class LiteSession : public session::LiteSession {
int Init(Context *context);
void BindThread(bool ifBind) override;
void BindThread(bool if_bind) override;
int CompileGraph(Model *model) override;
@@ -68,16 +68,16 @@ class LiteSession : public session::LiteSession {
protected:
Context *context_ = nullptr;
std::vector<kernel::LiteKernel *> kernels;
std::vector<tensor::Tensor *> tensors;
std::vector<kernel::LiteKernel *> kernels_;
std::vector<tensor::Tensor *> tensors_;
// graph input tensors
std::vector<tensor::Tensor *> inputs;
std::vector<tensor::Tensor *> inputs_;
// graph output tensors
std::vector<tensor::Tensor *> outputs;
std::vector<tensor::Tensor *> outputs_;
// graph input node name -- input tensors
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_;
// graph output node name -- output tensors
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> output_map;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> output_map_;
};
} // namespace lite
} // namespace mindspore

@@ -30,9 +30,7 @@ Model *Model::Import(const char *model_buf, size_t size) {
return model;
}
Model::~Model() {
delete(this->model_impl_);
}
Model::~Model() { delete (this->model_impl_); }
lite::Primitive *Model::GetOp(const std::string &name) const {
MS_EXCEPTION_IF_NULL(model_impl_);
@@ -46,7 +44,7 @@ void Model::FreeMetaGraph() {
const schema::MetaGraph *Model::GetMetaGraph() const {
MS_EXCEPTION_IF_NULL(model_impl_);
return model_impl_->GetMetaGraph();
return model_impl_->meta_graph();
}
ModelImpl *Model::model_impl() {

File diff suppressed because it is too large.

@@ -30,25 +30,24 @@ class ModelImpl {
static ModelImpl *Import(const char *model_buf, size_t size);
ModelImpl() = default;
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
meta_graph = schema::GetMetaGraph(model_buf);
meta_graph_ = schema::GetMetaGraph(model_buf);
}
virtual ~ModelImpl();
lite::Primitive *GetOp(const std::string &name) const;
const schema::MetaGraph *GetMetaGraph() const;
const schema::MetaGraph *meta_graph() const;
void FreeMetaGraph();
int BuildOps();
protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim);
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);
protected:
const char *model_buf_;
size_t buf_size_;
const schema::MetaGraph *meta_graph = nullptr;
std::map<std::string, lite::Primitive *> ops;
const schema::MetaGraph *meta_graph_ = nullptr;
std::map<std::string, lite::Primitive *> ops_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_

@@ -33,8 +33,8 @@ using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel {
int ArgMinMaxBaseCPUKernel::Init() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
switch (opParameter->type_) {
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
switch (op_parameter_->type_) {
case PrimitiveType_ArgMax:
param->get_max_ = true;
break;
@@ -42,7 +42,7 @@ int ArgMinMaxBaseCPUKernel::Init() {
param->get_max_ = false;
break;
default:
MS_LOG(ERROR) << "Unexpected type " << opParameter->type_;
MS_LOG(ERROR) << "Unexpected type " << op_parameter_->type_;
return RET_ERROR;
}
@@ -50,9 +50,9 @@
}
int ArgMinMaxBaseCPUKernel::ReSize() {
auto in_shape = inputs_.at(0)->shape();
auto in_shape = in_tensors_.at(0)->shape();
auto dims_size = in_shape.size();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
param->axis_ = axis;
param->dims_size_ = dims_size;
@@ -75,25 +75,25 @@
}
}
ComputeStrides(in_shape.data(), param->in_strides_, in_shape.size());
auto out_shape = outputs_.at(0)->shape();
auto out_shape = out_tensors_.at(0)->shape();
ComputeStrides(out_shape.data(), param->out_strides_, out_shape.size());
return RET_OK;
}
int ArgMinMaxBaseCPUKernel::Run() {
auto input = inputs_.at(0);
auto input = in_tensors_.at(0);
auto input_data = reinterpret_cast<const void *>(inputs_.at(0)->Data());
auto output_data = outputs_.at(0)->Data();
auto input_data = reinterpret_cast<const void *>(in_tensors_.at(0)->Data());
auto output_data = out_tensors_.at(0)->Data();
auto shape = input->shape().data();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
ArgMinMax(input_data, output_data, reinterpret_cast<const int *>(shape), param);
return RET_OK;
}
void ArgMinMaxBaseCPUKernel::FreeTmpMemory() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
if (param->arg_elements_ == nullptr) {
return;
}

@@ -30,7 +30,7 @@ using mindspore::schema::PrimitiveType_BatchToSpace;
namespace mindspore::kernel {
int BatchToSpaceBaseCPUKernel::Init() {
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) {
no_crop_ = false;
@@ -40,7 +40,7 @@ int BatchToSpaceBaseCPUKernel::Init() {
}
int BatchToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
if (in_tensors_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}

@@ -32,7 +32,7 @@ namespace mindspore::kernel {
int ConcatBaseCPUKernel::Init() { return RET_OK; }
int ConcatBaseCPUKernel::ReSize() {
axis_ = concat_param_->axis_ >= 0 ? concat_param_->axis_ : inputs_.front()->shape().size() + concat_param_->axis_;
axis_ = concat_param_->axis_ >= 0 ? concat_param_->axis_ : in_tensors_.front()->shape().size() + concat_param_->axis_;
return RET_OK;
}

@@ -31,7 +31,7 @@ class ConcatBaseCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
concat_param_ = reinterpret_cast<ConcatParameter *>(opParameter);
concat_param_ = reinterpret_cast<ConcatParameter *>(op_parameter_);
}
virtual ~ConcatBaseCPUKernel() = default;

@@ -77,8 +77,8 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
}
int ConvolutionBaseCPUKernel::Init() {
auto input = this->inputs_.front();
auto output = this->outputs_.front();
auto input = this->in_tensors_.front();
auto output = this->out_tensors_.front();
conv_param_->input_batch_ = input->Batch();
conv_param_->input_h_ = input->Height();
conv_param_->input_w_ = input->Width();
@@ -118,9 +118,9 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
return RET_ERROR;
}
}
auto input_tensor = inputs_.at(kInputIndex);
auto weight_tensor = inputs_.at(kWeightIndex);
auto output_tensor = outputs_.at(kOutputIndex);
auto input_tensor = in_tensors_.at(kInputIndex);
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto output_tensor = out_tensors_.at(kOutputIndex);
auto input_quant_arg = input_tensor->GetQuantParams().front();
auto weight_quant_arg = weight_tensor->GetQuantParams().front();
auto output_quant_arg = output_tensor->GetQuantParams().front();

@@ -40,8 +40,8 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
opParameter->thread_num_ = ctx->thread_num_;
conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
op_parameter_->thread_num_ = ctx->thread_num_;
conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_);
}
~ConvolutionBaseCPUKernel() override;

@@ -34,27 +34,27 @@ namespace mindspore::kernel {
int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; }
int DepthToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
if (in_tensors_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(op_parameter_);
if (param->block_size_ <= 0) {
MS_LOG(ERROR) << "Input block_size should > 0!";
return RET_PARAM_INVALID;
}
auto shape_size = inputs_[0]->shape().size();
auto shape_size = in_tensors_[0]->shape().size();
if (shape_size != DIMENSION_4D) {
MS_LOG(ERROR) << "Input shape size should be " << DIMENSION_4D;
return RET_PARAM_INVALID;
}
int32_t in_strides[DIMENSION_4D];
ComputeStrides(const_cast<int *>(inputs_[0]->shape().data()), in_strides, shape_size);
ComputeStrides(const_cast<int *>(in_tensors_[0]->shape().data()), in_strides, shape_size);
param->in_stride_dim0_ = in_strides[0];
param->in_stride_dim1_ = in_strides[1];
param->in_stride_dim2_ = in_strides[2];
int32_t out_strides[DIMENSION_4D];
ComputeStrides(const_cast<int *>(outputs_[0]->shape().data()), out_strides, shape_size);
ComputeStrides(const_cast<int *>(out_tensors_[0]->shape().data()), out_strides, shape_size);
param->out_stride_dim0_ = out_strides[0];
param->out_stride_dim1_ = out_strides[1];
param->out_stride_dim2_ = out_strides[2];

@@ -31,7 +31,7 @@ class FullconnectionBaseCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
fc_param_ = reinterpret_cast<MatMulParameter *>(opParameter);
fc_param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~FullconnectionBaseCPUKernel() = default;

@@ -59,4 +59,3 @@ LayoutConvertor LayoutTransform(TypeId data_type, schema::Format src_format, sch
}
}
} // namespace mindspore::kernel

Some files were not shown because too many files have changed in this diff.
