From dcea00321a82e78169ac6f5e3bcc31c497508eb7 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Tue, 22 Sep 2020 14:35:12 +0800 Subject: [PATCH] optimize custom vector --- mindspore/lite/internal/include/vector.h | 38 ++++--- mindspore/lite/internal/src/common/vector.cc | 107 ++++++++++-------- .../lite/internal/src/kernel/fp32/matmul.cc | 10 +- mindspore/lite/internal/src/lite_log.h | 5 +- 4 files changed, 88 insertions(+), 72 deletions(-) diff --git a/mindspore/lite/internal/include/vector.h b/mindspore/lite/internal/include/vector.h index 775400652c..17fd40cb48 100644 --- a/mindspore/lite/internal/include/vector.h +++ b/mindspore/lite/internal/include/vector.h @@ -17,26 +17,24 @@ #define MINDSPORE_LITE_INTERNAL_INCLUDE_VECTOR_H #include +#include +#include +#include #include -#include "internal/include/string.h" -#define DEFAULT_CAPACITY 1 +#define DEFAULT_CAPACITY 4 struct MSTensor; struct Node; template class Vector { - private: - size_t size_; - size_t elem_size_; - size_t capacity_; - T *data_; - public: Vector(); explicit Vector(size_t size); + Vector(size_t size, const T &value); + Vector(const Vector &vector); ~Vector(); @@ -92,23 +90,29 @@ class Vector { void reserve(size_t capacity); Vector &operator=(const Vector &v); + + private: + size_t size_; + size_t elem_size_; + size_t capacity_; + T *data_; }; template bool operator==(const Vector &lhs, const Vector &rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (int i = 0; i < lhs.size(); ++i) { - if (lhs[i] != rhs[i]) { - return false; - } + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i] != rhs[i]) { + return false; } - return true; + } + return true; } template bool operator!=(const Vector &lhs, const Vector &rhs) { - return !(lhs == rhs); + return !(lhs == rhs); } #endif // MINDSPORE_LITE_INTERNAL_INCLUDE_VECTOR_H diff --git a/mindspore/lite/internal/src/common/vector.cc b/mindspore/lite/internal/src/common/vector.cc index 5e9839145d..6d9201c6aa 100644 --- a/mindspore/lite/internal/src/common/vector.cc +++ b/mindspore/lite/internal/src/common/vector.cc @@ -14,13 +14,12 @@ * limitations under the License. */ #include "internal/include/vector.h" -#include -#include +#include "internal/include/string.h" #include "internal/src/lite_log.h" -#define min(x, y) ((x < y) ? (x) : (y)) +#define MIN(x, y) ((x < y) ? (x) : (y)) -template +template Vector::Vector() { size_ = 0; capacity_ = DEFAULT_CAPACITY; @@ -28,7 +27,7 @@ Vector::Vector() { data_ = nullptr; } -template +template Vector::Vector(size_t size) { size_ = size; elem_size_ = sizeof(T); @@ -40,7 +39,21 @@ Vector::Vector(size_t size) { memset(data_, 0, capacity_ * elem_size_); } -template +template +Vector::Vector(size_t size, const T &value) { + size_ = size; + elem_size_ = sizeof(T); + capacity_ = size; + data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); + if (data_ == nullptr) { + MS_C_EXCEPTION("malloc data failed"); + } + for (int i = 0; i < size; ++i) { + data_[i] = value; + } +} + +template Vector::Vector(const Vector &vec) { size_ = vec.size_; elem_size_ = sizeof(T); @@ -52,7 +65,7 @@ Vector::Vector(const Vector &vec) { memcpy(data_, vec.data_, size_ * elem_size_); } -template +template Vector &Vector::operator=(const Vector &vec) { if (this == &vec) { return *this; @@ -68,14 +81,14 @@ Vector &Vector::operator=(const Vector &vec) { return *this; } -template +template Vector::~Vector() { if (data_ != nullptr) { free(data_); } } -template +template void Vector::clear() { size_ = 0; if (data_ != nullptr) { @@ -84,7 +97,7 @@ void Vector::clear() { } } -template +template void Vector::push_back(const T &elem) { if (data_ == nullptr) { data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); @@ -95,14 +108,14 @@ void Vector::push_back(const T &elem) { resize(size_ + 1); --size_; } - memcpy(data_ + size_, &elem, elem_size_); + data_[size_] = elem; ++size_; } -template +template void Vector::push_back(T &&elem) { if (data_ == nullptr) { - data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); + data_ = reinterpret_cast(malloc(elem_size_)); if (data_ == nullptr) { MS_C_EXCEPTION("malloc data failed"); } @@ -110,11 +123,11 @@ void Vector::push_back(T &&elem) { resize(size_ + 1); --size_; } - memcpy(data_ + size_, &elem, elem_size_); + data_[size_] = elem; ++size_; } -template +template void Vector::pop_back() { if (size_ > 0) { --size_; @@ -123,7 +136,7 @@ void Vector::pop_back() { } } -template +template void Vector::insert(const T &elem, size_t index) { if (index <= size_) { ++size_; @@ -134,121 +147,121 @@ void Vector::insert(const T &elem, size_t index) { push_back(elem); } else { memmove(data_ + index + 1, data_ + index, (size_ - index - 1) * elem_size_); - memcpy(data_ + index, &elem, elem_size_); + data_[index] = elem; } } else { MS_C_EXCEPTION("Input index is out of range!"); } } -template +template T *Vector::begin() { return data_; } -template +template const T *Vector::begin() const { return data_; } -template +template T *Vector::end() { return data_ + size_; } -template +template const T *Vector::end() const { return data_ + size_; } -template +template T &Vector::front() { if (size_ > 0) { - return *data_; + return data_[0]; } MS_C_EXCEPTION("Index is out of range!"); } -template +template const T &Vector::front() const { if (size_ > 0) { - return *data_; + return data_[0]; } MS_C_EXCEPTION("Index is out of range!"); } -template +template T &Vector::back() { if (size_ > 0) { - return *(data_ + size_ - 1); + return data_[size_ - 1]; } MS_C_EXCEPTION("Index is out of range!"); } -template +template const T &Vector::back() const { if (size_ > 0) { - return *(data_ + size_ - 1); + return data_[size_ - 1]; } MS_C_EXCEPTION("Index is out of range!"); } -template +template T &Vector::at(size_t index) { if (index < size_) { - return *(data_ + index); + return data_[index]; } MS_C_EXCEPTION("Input index is out of range!"); } -template +template const T &Vector::at(size_t index) const { if (index < size_) { - return *(data_ + index); + return data_[index]; } MS_C_EXCEPTION("Input index is out of range!"); } -template +template T &Vector::operator[](size_t index) { if (index < size_) { - return *(data_ + index); + return data_[index]; } MS_C_EXCEPTION("Input index is out of range!"); } -template +template const T &Vector::operator[](size_t index) const { if (index < size_) { - return *(data_ + index); + return data_[index]; } MS_C_EXCEPTION("Input index is out of range!"); } -template +template T *Vector::data() { return data_; } -template +template const T *Vector::data() const { return data_; } -template +template size_t Vector::size() const { return size_; } -template +template size_t Vector::capacity() const { return capacity_; } -template +template bool Vector::empty() const { return size_ == 0; } -template +template void Vector::erase(size_t index) { if (index == size_ - 1) { --size_; @@ -260,9 +273,9 @@ void Vector::erase(size_t index) { } } -template +template void Vector::resize(size_t size) { - if (size > capacity_) { + while (size > capacity_) { capacity_ *= 2; } T *tmp = data_; @@ -270,12 +283,12 @@ void Vector::resize(size_t size) { if (data_ == nullptr) { MS_C_EXCEPTION("malloc data failed"); } - memcpy(data_, tmp, min(size, size_) * elem_size_); + memcpy(data_, tmp, MIN(size, size_) * elem_size_); size_ = size; free(tmp); } -template +template void Vector::reserve(size_t capacity) { if (capacity > capacity_) { capacity_ = capacity; diff --git a/mindspore/lite/internal/src/kernel/fp32/matmul.cc b/mindspore/lite/internal/src/kernel/fp32/matmul.cc index 9342fbe3b9..69530417bf 100644 --- a/mindspore/lite/internal/src/kernel/fp32/matmul.cc +++ b/mindspore/lite/internal/src/kernel/fp32/matmul.cc @@ -85,6 +85,7 @@ int DoMatMulInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector int *in_shape[2] = {input0->shape_.data(), input1->shape_.data()}; int out_format; int out_datatype; + output->shape_.resize(input0->shape_.size()); int ret = MatMulInferShape(in_shape, 2, dim_size, output->shape_.data(), in_format, &out_format, in_datatype, &out_datatype, param); if (ret != NNACL_OK) { @@ -134,16 +135,16 @@ int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso LITE_LOG_ERROR("Malloc MatMulCPUKernelData failed"); return RET_MEMORY_FAILED; } - kernel_data->a_c12_ptr_ - = reinterpret_cast(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float))); + kernel_data->a_c12_ptr_ = + reinterpret_cast(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float))); if (kernel_data->a_c12_ptr_ == NULL) { FreeMatMulKernelData(kernel_data, allocator); return RET_MEMORY_FAILED; } memset(kernel_data->a_c12_ptr_, 0, params->row_12_ * params->deep_ * sizeof(float)); - kernel_data->b_r8_ptr_ - = reinterpret_cast(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float))); + kernel_data->b_r8_ptr_ = + reinterpret_cast(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float))); if (kernel_data->b_r8_ptr_ == NULL) { FreeMatMulKernelData(kernel_data, allocator); return RET_MEMORY_FAILED; @@ -173,4 +174,3 @@ int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso return RET_OK; } - diff --git a/mindspore/lite/internal/src/lite_log.h b/mindspore/lite/internal/src/lite_log.h index ae42a87fe5..720e7eeb1a 100644 --- a/mindspore/lite/internal/src/lite_log.h +++ b/mindspore/lite/internal/src/lite_log.h @@ -23,13 +23,12 @@ #include #endif -#ifndef Release +#ifdef Debug #define LITE_DEBUG_LOG(format, ...) \ printf("[DEBUG] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_INFO_LOG(format, ...) \ printf("[INFO] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) -#define LITE_LOG_INFO(...) \ - printf("[INFO] [%s %s] [%s] [%d] %s\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) +#define LITE_LOG_INFO(...) printf("[INFO] [%s %s] [%s] [%d] %s\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_WARNING_LOG(format, ...) \ printf("[WARNING] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_ERROR_LOG(format, ...) \