optimize custom vector

pull/6713/head
jianghui58 4 years ago
parent 93d742732e
commit dcea00321a

@ -17,26 +17,24 @@
#define MINDSPORE_LITE_INTERNAL_INCLUDE_VECTOR_H #define MINDSPORE_LITE_INTERNAL_INCLUDE_VECTOR_H
#include <stdint.h> #include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stddef.h>
#include <initializer_list> #include <initializer_list>
#include "internal/include/string.h" #define DEFAULT_CAPACITY 4
#define DEFAULT_CAPACITY 1
struct MSTensor; struct MSTensor;
struct Node; struct Node;
template <typename T> template <typename T>
class Vector { class Vector {
private:
size_t size_;
size_t elem_size_;
size_t capacity_;
T *data_;
public: public:
Vector(); Vector();
explicit Vector(size_t size); explicit Vector(size_t size);
Vector(size_t size, const T &value);
Vector(const Vector<T> &vector); Vector(const Vector<T> &vector);
~Vector(); ~Vector();
@ -92,6 +90,12 @@ class Vector {
void reserve(size_t capacity); void reserve(size_t capacity);
Vector<T> &operator=(const Vector<T> &v); Vector<T> &operator=(const Vector<T> &v);
private:
size_t size_;
size_t elem_size_;
size_t capacity_;
T *data_;
}; };
template <typename T> template <typename T>

@ -14,11 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include "internal/include/vector.h" #include "internal/include/vector.h"
#include <stdlib.h> #include "internal/include/string.h"
#include <string.h>
#include "internal/src/lite_log.h" #include "internal/src/lite_log.h"
#define min(x, y) ((x < y) ? (x) : (y)) #define MIN(x, y) ((x < y) ? (x) : (y))
template <typename T> template <typename T>
Vector<T>::Vector() { Vector<T>::Vector() {
@ -40,6 +39,20 @@ Vector<T>::Vector(size_t size) {
memset(data_, 0, capacity_ * elem_size_); memset(data_, 0, capacity_ * elem_size_);
} }
template <typename T>
Vector<T>::Vector(size_t size, const T &value) {
size_ = size;
elem_size_ = sizeof(T);
capacity_ = size;
data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
}
for (int i = 0; i < size; ++i) {
data_[i] = value;
}
}
template <typename T> template <typename T>
Vector<T>::Vector(const Vector<T> &vec) { Vector<T>::Vector(const Vector<T> &vec) {
size_ = vec.size_; size_ = vec.size_;
@ -95,14 +108,14 @@ void Vector<T>::push_back(const T &elem) {
resize(size_ + 1); resize(size_ + 1);
--size_; --size_;
} }
memcpy(data_ + size_, &elem, elem_size_); data_[size_] = elem;
++size_; ++size_;
} }
template <typename T> template <typename T>
void Vector<T>::push_back(T &&elem) { void Vector<T>::push_back(T &&elem) {
if (data_ == nullptr) { if (data_ == nullptr) {
data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); data_ = reinterpret_cast<T *>(malloc(elem_size_));
if (data_ == nullptr) { if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed"); MS_C_EXCEPTION("malloc data failed");
} }
@ -110,7 +123,7 @@ void Vector<T>::push_back(T &&elem) {
resize(size_ + 1); resize(size_ + 1);
--size_; --size_;
} }
memcpy(data_ + size_, &elem, elem_size_); data_[size_] = elem;
++size_; ++size_;
} }
@ -134,7 +147,7 @@ void Vector<T>::insert(const T &elem, size_t index) {
push_back(elem); push_back(elem);
} else { } else {
memmove(data_ + index + 1, data_ + index, (size_ - index - 1) * elem_size_); memmove(data_ + index + 1, data_ + index, (size_ - index - 1) * elem_size_);
memcpy(data_ + index, &elem, elem_size_); data_[index] = elem;
} }
} else { } else {
MS_C_EXCEPTION("Input index is out of range!"); MS_C_EXCEPTION("Input index is out of range!");
@ -164,7 +177,7 @@ const T *Vector<T>::end() const {
template <typename T> template <typename T>
T &Vector<T>::front() { T &Vector<T>::front() {
if (size_ > 0) { if (size_ > 0) {
return *data_; return data_[0];
} }
MS_C_EXCEPTION("Index is out of range!"); MS_C_EXCEPTION("Index is out of range!");
} }
@ -172,21 +185,21 @@ T &Vector<T>::front() {
template <typename T> template <typename T>
const T &Vector<T>::front() const { const T &Vector<T>::front() const {
if (size_ > 0) { if (size_ > 0) {
return *data_; return data_[0];
} }
MS_C_EXCEPTION("Index is out of range!"); MS_C_EXCEPTION("Index is out of range!");
} }
template <typename T> template <typename T>
T &Vector<T>::back() { T &Vector<T>::back() {
if (size_ > 0) { if (size_ > 0) {
return *(data_ + size_ - 1); return data_[size_ - 1];
} }
MS_C_EXCEPTION("Index is out of range!"); MS_C_EXCEPTION("Index is out of range!");
} }
template <typename T> template <typename T>
const T &Vector<T>::back() const { const T &Vector<T>::back() const {
if (size_ > 0) { if (size_ > 0) {
return *(data_ + size_ - 1); return data_[size_ - 1];
} }
MS_C_EXCEPTION("Index is out of range!"); MS_C_EXCEPTION("Index is out of range!");
} }
@ -194,7 +207,7 @@ const T &Vector<T>::back() const {
template <typename T> template <typename T>
T &Vector<T>::at(size_t index) { T &Vector<T>::at(size_t index) {
if (index < size_) { if (index < size_) {
return *(data_ + index); return data_[index];
} }
MS_C_EXCEPTION("Input index is out of range!"); MS_C_EXCEPTION("Input index is out of range!");
} }
@ -202,7 +215,7 @@ T &Vector<T>::at(size_t index) {
template <typename T> template <typename T>
const T &Vector<T>::at(size_t index) const { const T &Vector<T>::at(size_t index) const {
if (index < size_) { if (index < size_) {
return *(data_ + index); return data_[index];
} }
MS_C_EXCEPTION("Input index is out of range!"); MS_C_EXCEPTION("Input index is out of range!");
} }
@ -210,7 +223,7 @@ const T &Vector<T>::at(size_t index) const {
template <typename T> template <typename T>
T &Vector<T>::operator[](size_t index) { T &Vector<T>::operator[](size_t index) {
if (index < size_) { if (index < size_) {
return *(data_ + index); return data_[index];
} }
MS_C_EXCEPTION("Input index is out of range!"); MS_C_EXCEPTION("Input index is out of range!");
} }
@ -218,7 +231,7 @@ T &Vector<T>::operator[](size_t index) {
template <typename T> template <typename T>
const T &Vector<T>::operator[](size_t index) const { const T &Vector<T>::operator[](size_t index) const {
if (index < size_) { if (index < size_) {
return *(data_ + index); return data_[index];
} }
MS_C_EXCEPTION("Input index is out of range!"); MS_C_EXCEPTION("Input index is out of range!");
} }
@ -262,7 +275,7 @@ void Vector<T>::erase(size_t index) {
template <typename T> template <typename T>
void Vector<T>::resize(size_t size) { void Vector<T>::resize(size_t size) {
if (size > capacity_) { while (size > capacity_) {
capacity_ *= 2; capacity_ *= 2;
} }
T *tmp = data_; T *tmp = data_;
@ -270,7 +283,7 @@ void Vector<T>::resize(size_t size) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed"); MS_C_EXCEPTION("malloc data failed");
} }
memcpy(data_, tmp, min(size, size_) * elem_size_); memcpy(data_, tmp, MIN(size, size_) * elem_size_);
size_ = size; size_ = size;
free(tmp); free(tmp);
} }

@ -85,6 +85,7 @@ int DoMatMulInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector
int *in_shape[2] = {input0->shape_.data(), input1->shape_.data()}; int *in_shape[2] = {input0->shape_.data(), input1->shape_.data()};
int out_format; int out_format;
int out_datatype; int out_datatype;
output->shape_.resize(input0->shape_.size());
int ret = MatMulInferShape(in_shape, 2, dim_size, output->shape_.data(), in_format, &out_format, in_datatype, int ret = MatMulInferShape(in_shape, 2, dim_size, output->shape_.data(), in_format, &out_format, in_datatype,
&out_datatype, param); &out_datatype, param);
if (ret != NNACL_OK) { if (ret != NNACL_OK) {
@ -134,16 +135,16 @@ int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso
LITE_LOG_ERROR("Malloc MatMulCPUKernelData failed"); LITE_LOG_ERROR("Malloc MatMulCPUKernelData failed");
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
kernel_data->a_c12_ptr_ kernel_data->a_c12_ptr_ =
= reinterpret_cast<float *>(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float))); reinterpret_cast<float *>(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float)));
if (kernel_data->a_c12_ptr_ == NULL) { if (kernel_data->a_c12_ptr_ == NULL) {
FreeMatMulKernelData(kernel_data, allocator); FreeMatMulKernelData(kernel_data, allocator);
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
memset(kernel_data->a_c12_ptr_, 0, params->row_12_ * params->deep_ * sizeof(float)); memset(kernel_data->a_c12_ptr_, 0, params->row_12_ * params->deep_ * sizeof(float));
kernel_data->b_r8_ptr_ kernel_data->b_r8_ptr_ =
= reinterpret_cast<float *>(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float))); reinterpret_cast<float *>(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float)));
if (kernel_data->b_r8_ptr_ == NULL) { if (kernel_data->b_r8_ptr_ == NULL) {
FreeMatMulKernelData(kernel_data, allocator); FreeMatMulKernelData(kernel_data, allocator);
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
@ -173,4 +174,3 @@ int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso
return RET_OK; return RET_OK;
} }

@ -23,13 +23,12 @@
#include <assert.h> #include <assert.h>
#endif #endif
#ifndef Release #ifdef Debug
#define LITE_DEBUG_LOG(format, ...) \ #define LITE_DEBUG_LOG(format, ...) \
printf("[DEBUG] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) printf("[DEBUG] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__)
#define LITE_INFO_LOG(format, ...) \ #define LITE_INFO_LOG(format, ...) \
printf("[INFO] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) printf("[INFO] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__)
#define LITE_LOG_INFO(...) \ #define LITE_LOG_INFO(...) printf("[INFO] [%s %s] [%s] [%d] %s\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__)
printf("[INFO] [%s %s] [%s] [%d] %s\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__)
#define LITE_WARNING_LOG(format, ...) \ #define LITE_WARNING_LOG(format, ...) \
printf("[WARNING] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) printf("[WARNING] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__)
#define LITE_ERROR_LOG(format, ...) \ #define LITE_ERROR_LOG(format, ...) \

Loading…
Cancel
Save