!5928 [MS][LITE]optimize interface and data structure
Merge pull request !5928 from zhaizhiqiang/master
commit
c621f6f751
@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.14)
project (Lite_Internal)
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../)

include_directories(${TOP_DIR})

file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
file(GLOB KERNEL_SRC
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c
        )
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c)

set(CCSRC
        ${TOP_DIR}/src/common/log_adapter.cc
        ${TOP_DIR}/src/runtime/allocator.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/gvar/logging_level.cc
        )

if (PLATFORM_ARM64)
    # assembly
    file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/*.s
            ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/*.S)
    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
    set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
    add_library(mslite_internal SHARED ${C_SRC} ${CCSRC} ${KERNEL_SRC})
    target_link_libraries(mslite_internal log)
endif()
@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_INTERNAL_INCLUDE_CONTEXT_H_
#define MINDSPORE_LITE_INTERNAL_INCLUDE_CONTEXT_H_

/// \brief CpuBindMode defined for holding the CPU core binding strategy argument.
typedef enum {
  MID_CPU = -1,   /**< bind middle cpu first */
  HIGHER_CPU = 1, /**< bind higher cpu first */
  NO_BIND = 0     /**< no bind */
} CpuBindMode;

/// \brief DeviceType defined for holding the user's preferred backend.
typedef enum {
  DT_CPU, /**< CPU device type */
  DT_GPU, /**< GPU device type */
  DT_NPU  /**< NPU device type, not supported yet */
} DeviceType;

/// \brief Context defined for holding environment variables during runtime.
typedef struct {
  bool float16_priority = false; /**< prefer float16 inference when enabled */
  DeviceType device_type_ = DT_CPU;
  int thread_num_ = 2; /**< thread number config for thread pool */
} Context;
#endif  // MINDSPORE_LITE_INTERNAL_INCLUDE_CONTEXT_H_
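For reference, a minimal usage sketch of the Context above, with the documented default values written out explicitly. MakeDefaultContext is a hypothetical helper, not part of this change.

#include "internal/include/context.h"

Context MakeDefaultContext() {
  Context ctx;
  ctx.device_type_ = DT_CPU;     // run on CPU; DT_GPU exists, DT_NPU is declared but not supported yet
  ctx.thread_num_ = 2;           // thread pool size used for CPU kernels
  ctx.float16_priority = false;  // set true to prefer float16 inference
  return ctx;
}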
@ -0,0 +1,55 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_INTERNAL_INCLUDE_ERRORCODE_H_
#define MINDSPORE_LITE_INTERNAL_INCLUDE_ERRORCODE_H_

/// \brief STATUS defined for holding error codes in MindSpore Lite.
using STATUS = int;

/* Success */
constexpr int RET_OK = 0; /**< No error occurs. */

/* Common error code, range: [-1, -100] */
constexpr int RET_ERROR = -1;         /**< Common error code. */
constexpr int RET_NULL_PTR = -2;      /**< NULL pointer returned. */
constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter. */
constexpr int RET_NO_CHANGE = -4;     /**< No change. */
constexpr int RET_SUCCESS_EXIT = -5;  /**< No error but exit. */
constexpr int RET_MEMORY_FAILED = -6; /**< Failed to allocate memory. */

/* Executor error code, range: [-101, -200] */
constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */
constexpr int RET_INPUT_TENSOR_ERROR = -102;  /**< Failed to check input tensor. */
constexpr int RET_REENTRANT_ERROR = -103;     /**< An executor is already running. */

/* Graph error code, range: [-201, -300] */
constexpr int RET_GRAPH_FILE_ERR = -201; /**< Failed to verify graph file. */

/* Node error code, range: [-301, -400] */
constexpr int RET_NOT_FIND_OP = -301;        /**< Failed to find operator. */
constexpr int RET_INVALID_OP_NAME = -302;    /**< Invalid operator name. */
constexpr int RET_INVALID_OP_ATTR = -303;    /**< Invalid operator attribute. */
constexpr int RET_OP_EXECUTE_FAILURE = -304; /**< Failed to execute operator. */

/* Tensor error code, range: [-401, -500] */
constexpr int RET_FORMAT_ERR = -401; /**< Failed to check tensor format. */

/* InferShape error code, range: [-501, -600] */
constexpr int RET_INFER_ERR = -501;     /**< Failed to infer shape. */
constexpr int RET_INFER_INVALID = -502; /**< Invalid infer shape before runtime. */

#endif  // MINDSPORE_LITE_INTERNAL_INCLUDE_ERRORCODE_H_
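A hedged sketch of how these codes are typically checked at a call site. CompileGraph comes from lite_session.h below; CheckCompile and the printed message are illustrative only.

#include <cstdio>
#include "internal/include/errorcode.h"
#include "internal/include/lite_session.h"

// Hypothetical helper: turn a STATUS result into a pass/fail decision.
bool CheckCompile(LiteSession *session, Model *model) {
  STATUS ret = session->CompileGraph(model);
  if (ret != RET_OK) {
    // Negative codes are grouped by component, e.g. RET_NOT_FIND_OP (-301) for node errors.
    printf("CompileGraph failed, error code: %d\n", ret);
    return false;
  }
  return true;
}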
@ -0,0 +1,90 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_INTERNAL_INCLUDE_LITE_SESSION_H
#define MINDSPORE_LITE_INTERNAL_INCLUDE_LITE_SESSION_H

#include "internal/include/ms_tensor.h"
#include "internal/include/model.h"
#include "internal/include/context.h"
#include "internal/include/lite_utils.h"

/// \brief LiteSession defines a session in MindSpore Lite for compiling and running a Model.
typedef struct LiteSession {
  /// \brief Static method to create a LiteSession pointer.
  ///
  /// \param[in] context Define the context of the session to be created.
  ///
  /// \return Pointer of MindSpore Lite LiteSession.
  static LiteSession *CreateSession(Context *context);

  /// \brief Compile MindSpore Lite model.
  ///
  /// \note CompileGraph should be called before RunGraph.
  ///
  /// \param[in] model Define the model to be compiled.
  ///
  /// \return STATUS as an error code of compiling the graph, STATUS is defined in errorcode.h.
  int CompileGraph(Model *model);

  /// \brief Get input MindSpore Lite MSTensors of the model.
  ///
  /// \return The vector of MindSpore Lite MSTensor.
  TensorPtrVector GetInputs() const;

  /// \brief Get input MindSpore Lite MSTensors of the model by node name.
  ///
  /// \param[in] node_name Define node name.
  ///
  /// \return The vector of MindSpore Lite MSTensor.
  TensorPtrVector GetInputsByName(const String &node_name) const;

  /// \brief Get output MindSpore Lite MSTensors of the model by node name.
  ///
  /// \param[in] node_name Define node name.
  ///
  /// \return The vector of MindSpore Lite MSTensor.
  TensorPtrVector GetOutputsByNodeName(const String &node_name) const;

  /// \brief Get output MindSpore Lite MSTensors of the model.
  ///
  /// \return The vector of MindSpore Lite MSTensor.
  TensorPtrVector GetOutputs() const;

  /// \brief Get names of the output tensors of the model compiled by this session.
  ///
  /// \return The vector of string as output tensor names in order.
  StringVector GetOutputTensorNames() const;

  /// \brief Get output MindSpore Lite MSTensor of the model by tensor name.
  ///
  /// \param[in] tensor_name Define tensor name.
  ///
  /// \return Pointer of MindSpore Lite MSTensor.
  MSTensor *GetOutputByTensorName(const String &tensor_name) const;

  /// \brief Run the compiled graph for inference.
  ///
  /// \note RunGraph should be called after CompileGraph.
  ///
  /// \return STATUS as an error code of running the graph, STATUS is defined in errorcode.h.
  int RunGraph();

  /// \brief Resize inputs shape.
  ///
  /// \param[in] inputs Define the new inputs shape.
  ///
  /// \return STATUS as an error code of resizing inputs, STATUS is defined in errorcode.h.
  int Resize(const TensorPtrVector &inputs);
} LiteSession;

#endif  // MINDSPORE_LITE_INTERNAL_INCLUDE_LITE_SESSION_H
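Putting the interface together, a hedged end-to-end usage sketch. RunOnce is a hypothetical wrapper, the caller is assumed to provide the serialized model buffer, and error handling is trimmed for brevity; how input data is filled before RunGraph is not prescribed by this diff.

#include "internal/include/context.h"
#include "internal/include/errorcode.h"
#include "internal/include/lite_session.h"
#include "internal/include/model.h"

int RunOnce(const char *model_buf, size_t model_size) {
  Context ctx;                                      // defaults: CPU backend, 2 threads
  LiteSession *session = LiteSession::CreateSession(&ctx);
  Model *model = Model::Import(model_buf, model_size);
  if (session == NULL || model == NULL) {
    return RET_NULL_PTR;
  }
  if (session->CompileGraph(model) != RET_OK) {     // must precede RunGraph
    return RET_ERROR;
  }
  TensorPtrVector inputs = session->GetInputs();    // fill input tensor data here
  int ret = session->RunGraph();
  TensorPtrVector outputs = session->GetOutputs();  // read results after RunGraph
  model->Free();                                    // release temporary model buffers
  return ret;
}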
@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INTERNAL_INCLUDE_MODEL_H
#define MINDSPORE_LITE_INTERNAL_INCLUDE_MODEL_H
#include "internal/include/lite_utils.h"
#include "nnacl/op_base.h"

using PrimitiveC = OpParameter;
enum NodeType {
  NodeType_ValueNode = 0,
  NodeType_Parameter = 1,
  NodeType_CNode = 2,
  NodeType_MIN = NodeType_ValueNode,
  NodeType_MAX = NodeType_CNode
};

typedef struct Node {
  String name_;
  NodeType node_type_;
  PrimitiveC *primitive_;
  Uint32Vector input_indices_;
  Uint32Vector output_indices_;
} Node;

typedef struct Model {
  String name_;
  String version_;
  TensorPtrVector all_tensors_;
  Uint32Vector input_indices_;
  Uint32Vector output_indices_;
  NodePtrVector nodes_;
  char *buf;

  /// \brief Static method to create a Model pointer.
  ///
  /// \param[in] model_buf Define the buffer read from a model file.
  /// \param[in] size Define the byte size of the model buffer.
  ///
  /// \return Pointer of MindSpore Lite Model.
  static Model *Import(const char *model_buf, size_t size);

  /// \brief Free all the temporary buffers.
  void Free();
} Model;

#endif  // MINDSPORE_LITE_INTERNAL_INCLUDE_MODEL_H
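A short hedged sketch of loading a model buffer from disk and handing it to Model::Import. LoadModelFromFile and the path handling are illustrative; whether Import copies the buffer or keeps referencing it through Model::buf is not shown in this diff, so a real loader would need to manage the buffer lifetime accordingly.

#include <fstream>
#include <vector>
#include "internal/include/model.h"

Model *LoadModelFromFile(const char *path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);
  if (!ifs.is_open()) {
    return NULL;
  }
  std::streamsize size = ifs.tellg();   // file size from the end position
  ifs.seekg(0, std::ios::beg);
  std::vector<char> buf(static_cast<size_t>(size));
  if (!ifs.read(buf.data(), size)) {
    return NULL;
  }
  // Import parses the serialized graph from the raw buffer.
  return Model::Import(buf.data(), static_cast<size_t>(size));
}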
@ -0,0 +1,142 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_INTERNAL_INCLUDE_MS_TENSOR_H_
#define MINDSPORE_LITE_INTERNAL_INCLUDE_MS_TENSOR_H_

#include "internal/include/lite_utils.h"

enum TypeId : int {
  kTypeUnknown = 0,
  kMetaTypeBegin = kTypeUnknown,
  kMetaTypeType,  // Type
  kMetaTypeAnything,
  kMetaTypeObject,
  kMetaTypeTypeType,  // TypeType
  kMetaTypeProblem,
  kMetaTypeExternal,
  kMetaTypeNone,
  kMetaTypeNull,
  kMetaTypeEllipsis,
  kMetaTypeEnd,
  //
  // Object types
  //
  kObjectTypeBegin = kMetaTypeEnd,
  kObjectTypeNumber,
  kObjectTypeString,
  kObjectTypeList,
  kObjectTypeTuple,
  kObjectTypeSlice,
  kObjectTypeKeyword,
  kObjectTypeTensorType,
  kObjectTypeRowTensorType,
  kObjectTypeSparseTensorType,
  kObjectTypeUndeterminedType,
  kObjectTypeClass,
  kObjectTypeDictionary,
  kObjectTypeFunction,
  kObjectTypeJTagged,
  kObjectTypeSymbolicKeyType,
  kObjectTypeEnvType,
  kObjectTypeRefKey,
  kObjectTypeRef,
  kObjectTypeEnd,
  //
  // Number Types
  //
  kNumberTypeBegin = kObjectTypeEnd,
  kNumberTypeBool,
  kNumberTypeInt,
  kNumberTypeInt8,
  kNumberTypeInt16,
  kNumberTypeInt32,
  kNumberTypeInt64,
  kNumberTypeUInt,
  kNumberTypeUInt8,
  kNumberTypeUInt16,
  kNumberTypeUInt32,
  kNumberTypeUInt64,
  kNumberTypeFloat,
  kNumberTypeFloat16,
  kNumberTypeFloat32,
  kNumberTypeFloat64,
  kNumberTypeEnd
};

enum Format {
  Format_NCHW = 0,
  Format_NHWC = 1,
  Format_NHWC4 = 2,
  Format_HWKC = 3,
  Format_HWCK = 4,
  Format_KCHW = 5,
  Format_CKHW = 6,
  Format_KHWC = 7,
  Format_CHWK = 8,
  Format_HW = 9,
  Format_HW4 = 10,
  Format_NC = 11,
  Format_NC4 = 12,
  Format_NC4HW4 = 100,
  Format_NUM_OF_FORMAT = 101,
  Format_MIN = Format_NCHW,
  Format_MAX = Format_NUM_OF_FORMAT
};

typedef struct MSTensor {
  enum Category {
    CONST,  // weight tensor
    VAR     // activation tensor
  };
  void *data_ = NULL;
  void *device_data_ = NULL;
  TypeId data_type_;
  Format format_ = Format_NHWC;
  Category category_ = VAR;
  ShapeVector shape_ = {};
  size_t refCount = 0;

  int32_t Batch() const;

  int32_t Channel() const;

  int32_t Height() const;

  int32_t Width() const;

  /// \brief Get size of the dimension of the MindSpore Lite MSTensor indexed by the parameter index.
  ///
  /// \param[in] index Define index of the dimension returned.
  ///
  /// \return Size of the dimension of the MindSpore Lite MSTensor.
  int DimensionSize(size_t index) const;

  /// \brief Get the number of elements in the MSTensor.
  ///
  /// \return Number of elements in the MSTensor.
  int ElementsNum() const;

  /// \brief Get the number of elements in the MSTensor with the channel dimension aligned up to a multiple of 4.
  int ElementsC4Num() const;

  /// \brief Get byte size of data in the MSTensor.
  ///
  /// \return Byte size of data in the MSTensor.
  size_t Size() const;
} MSTensor;

MSTensor *CreateTensor(TypeId data_type, const ShapeVector &shape);
#endif  // MINDSPORE_LITE_INTERNAL_INCLUDE_MS_TENSOR_H_
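A hedged sketch of creating a tensor with the factory above and querying its geometry, assuming ShapeVector is constructible from an initializer list (as its use for shape_ suggests). InspectTensor is a hypothetical helper; the expected values follow the ms_tensor.cc implementation later in this diff.

#include "internal/include/ms_tensor.h"

void InspectTensor() {
  // NHWC float32 tensor: batch 1, height 224, width 224, channel 3.
  MSTensor *t = CreateTensor(kNumberTypeFloat32, {1, 224, 224, 3});
  int n = t->ElementsNum();  // 1 * 224 * 224 * 3 = 150528
  size_t bytes = t->Size();  // 150528 * sizeof(float) = 602112 for the default Format_NHWC
  int32_t h = t->Height();   // 224 (shape_[1] for NHWC)
  int32_t c = t->Channel();  // 3   (shape_[3] for NHWC)
  (void)n; (void)bytes; (void)h; (void)c;
  delete t;                  // CreateTensor allocates with new
}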
@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "internal/include/lite_session.h"
#include "internal/include/model.h"
#include "internal/include/ms_tensor.h"
#include "src/runtime/allocator.h"

static Context *g_Ctx;
static Model *g_Model;
static LiteSession g_Session;
static mindspore::lite::DefaultAllocator allocator;

LiteSession *LiteSession::CreateSession(Context *context) {
  g_Ctx = context;
  return &g_Session;
}

int LiteSession::CompileGraph(Model *model) {
  g_Model = model;
  for (auto in : g_Model->input_indices_) {
    g_Model->all_tensors_[in]->data_ = allocator.Malloc(g_Model->all_tensors_[in]->Size());
  }
  return 0;
}

TensorPtrVector LiteSession::GetInputs() const {
  TensorPtrVector in(g_Model->input_indices_.size());
  // for(auto index : g_Model->input_indices_){
  //   in.emplace_back(g_Model->all_tensors_[index]);
  // }
  return in;
}

TensorPtrVector LiteSession::GetInputsByName(const String &node_name) const { return TensorPtrVector(); }

TensorPtrVector LiteSession::GetOutputsByNodeName(const String &node_name) const { return TensorPtrVector(); }

TensorPtrVector LiteSession::GetOutputs() const {
  TensorPtrVector out(g_Model->output_indices_.size());
  // for(auto index : g_Model->output_indices_){
  //   out.emplace_back(g_Model->all_tensors_[index]);
  // }
  return out;
}

int LiteSession::RunGraph() {
  // invoke nnacl kernel
  return 0;
}

StringVector LiteSession::GetOutputTensorNames() const { return StringVector(); }

MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { return NULL; }

int LiteSession::Resize(const TensorPtrVector &inputs) { return 0; }
@ -0,0 +1,194 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <vector>
#include <numeric>
#include <string>
#include <functional>
#include "internal/include/ms_tensor.h"
MSTensor *CreateTensor(TypeId data_type, const ShapeVector &shape) {
  MSTensor *tensor = new MSTensor();
  tensor->shape_ = shape;
  tensor->data_type_ = data_type;
  return tensor;
}
int MSTensor::ElementsNum() const { return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int>()); }

size_t MSTensor::Size() const {
  size_t size = 0;
  switch (this->data_type_) {
    case kNumberTypeFloat64:
      size = sizeof(double);
      break;
    case kNumberTypeFloat:
    case kNumberTypeFloat32:
      size = sizeof(float);
      break;
    case kNumberTypeInt8:
      size = sizeof(int8_t);
      break;
    case kNumberTypeUInt8:
      size = sizeof(uint8_t);
      break;
    case kNumberTypeFloat16:
      size = sizeof(int16_t);
      break;
    case kNumberTypeInt16:
      size = sizeof(int16_t);
      break;
    case kNumberTypeInt32:
      size = sizeof(int32_t);
      break;
    case kNumberTypeInt64:
      size = sizeof(int64_t);
      break;
    case kNumberTypeUInt16:
      size = sizeof(uint16_t);
      break;
    case kNumberTypeUInt32:
      size = sizeof(uint32_t);
      break;
    case kNumberTypeUInt64:
      size = sizeof(uint64_t);
      break;
    case kNumberTypeBool:
      size = sizeof(bool);
      break;
    default:
      std::cout << "Not support the type: " << this->data_type_;
      return 0;
  }
  size *= (format_ == Format::Format_NC4HW4 || format_ == Format::Format_NHWC4) ? ElementsC4Num() : ElementsNum();

  return size;
}
int32_t MSTensor::Batch() const {
  if (this->shape_.size() != 4 && this->shape_.size() != 2) {
    std::cout << "Unsupported tensor shape: " << this->shape_.size();
    return -1;
  }
  switch (this->format_) {
    case Format::Format_NHWC:
    case Format::Format_NHWC4:
    case Format::Format_NCHW:
    case Format::Format_NC4HW4:
    case Format::Format_KCHW:
    case Format::Format_KHWC:
    case Format::Format_NC:
    case Format::Format_NC4:
      return this->shape_[0];
    case Format::Format_HWCK:
    case Format::Format_CHWK:
      return this->shape_[3];
    case Format::Format_HWKC:
      return this->shape_[2];
    case Format::Format_CKHW:
      return this->shape_[1];
    default:
      // std::cout << "Unsupported format: " << EnumNameFormat(this->format_);
      return -1;
  }
}

int32_t MSTensor::Channel() const {
  if (this->shape_.size() != 4 && this->shape_.size() != 2) {
    std::cout << "Unsupported tensor shape: " << this->shape_.size();
    return -1;
  }
  switch (this->format_) {
    case Format::Format_NCHW:
    case Format::Format_KCHW:
    case Format::Format_NC:
    case Format::Format_NC4:
      return this->shape_[1];
    case Format::Format_HWCK:
      return this->shape_[2];
    case Format::Format_HWKC:
    case Format::Format_NHWC:
    case Format::Format_NHWC4:
    case Format::Format_NC4HW4:
    case Format::Format_KHWC:
      return this->shape_[3];
    case Format::Format_CKHW:
    case Format::Format_CHWK:
      return this->shape_[0];
    default:
      return -1;
  }
}

int32_t MSTensor::Height() const {
  if (this->shape_.size() != 4 && this->shape_.size() != 2) {
    std::cout << "Unsupported tensor shape: " << this->shape_.size();
    return -1;
  }
  switch (this->format_) {
    case Format::Format_NCHW:
    case Format::Format_KCHW:
    case Format::Format_CKHW:
      return this->shape_[2];
    case Format::Format_NHWC:
    case Format::Format_NHWC4:
    case Format::Format_NC4HW4:
    case Format::Format_KHWC:
    case Format::Format_CHWK:
      return this->shape_[1];
    case Format::Format_HWCK:
    case Format::Format_HWKC:
    case Format::Format_HW:
    case Format::Format_HW4:
      return this->shape_[0];
    default:
      // std::cout << "Unsupported format: " << EnumNameFormat(this->format_);
      return -1;
  }
}

int32_t MSTensor::Width() const {
  if (this->shape_.size() != 4 && this->shape_.size() != 2) {
    std::cout << "Unsupported tensor shape: " << this->shape_.size();
    return -1;
  }
  switch (this->format_) {
    case Format::Format_NCHW:
    case Format::Format_KCHW:
    case Format::Format_CKHW:
      return this->shape_[3];
    case Format::Format_KHWC:
    case Format::Format_NHWC:
    case Format::Format_NHWC4:
    case Format::Format_NC4HW4:
    case Format::Format_CHWK:
      return this->shape_[2];
    case Format::Format_HWCK:
    case Format::Format_HWKC:
    case Format::Format_HW:
    case Format::Format_HW4:
      return this->shape_[1];
    default:
      return -1;
  }
}

int MSTensor::ElementsC4Num() const {
  int result = 0;
  if (this->shape_.size() == 4) {
    result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
  } else if (this->shape_.size() == 2) {
    result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4);
  }
  return result;
}
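As a worked check of the C4-aligned size path above (the shape and format values are illustrative):

// shape_ = {1, 16, 16, 3}, format_ = Format_NHWC4, data_type_ = kNumberTypeFloat32
//   ElementsC4Num() = 1 * 16 * 16 * ((3 + 3) / 4 * 4) = 1 * 16 * 16 * 4 = 1024
//   Size()          = sizeof(float) * 1024 = 4096 bytes
// With plain Format_NHWC the same shape uses ElementsNum() = 768, so Size() = 3072 bytes.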
Some files were not shown because too many files have changed in this diff.