|
|
|
@ -54,7 +54,7 @@ int Tensor::CopyTensorData(const Tensor &srcTensor) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
memcpy(this->data_, srcTensor.data_, data_size);
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Tensor::CopyTensor(const Tensor &srcTensor, bool copyData) {
|
|
|
|
@ -69,7 +69,7 @@ int Tensor::CopyTensor(const Tensor &srcTensor, bool copyData) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tensor::~Tensor() {
|
|
|
|
@ -102,7 +102,7 @@ bool Tensor::operator==(const Tensor &tensor) {
|
|
|
|
|
int32_t Tensor::Batch() const {
|
|
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
switch (this->format_) {
|
|
|
|
|
case schema::Format::Format_NHWC:
|
|
|
|
@ -123,14 +123,14 @@ int32_t Tensor::Batch() const {
|
|
|
|
|
return this->shape_[1];
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(this->format_);
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t Tensor::Channel() const {
|
|
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
switch (this->format_) {
|
|
|
|
|
case schema::Format::Format_NCHW:
|
|
|
|
@ -150,14 +150,14 @@ int32_t Tensor::Channel() const {
|
|
|
|
|
case schema::Format::Format_CHWK:
|
|
|
|
|
return this->shape_[0];
|
|
|
|
|
default:
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t Tensor::Height() const {
|
|
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
switch (this->format_) {
|
|
|
|
|
case schema::Format::Format_NCHW:
|
|
|
|
@ -177,7 +177,7 @@ int32_t Tensor::Height() const {
|
|
|
|
|
return this->shape_[0];
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(this->format_);
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -203,11 +203,28 @@ int32_t Tensor::Width() const {
|
|
|
|
|
case schema::Format::Format_HW4:
|
|
|
|
|
return this->shape_[1];
|
|
|
|
|
default:
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t Tensor::Size() const {
|
|
|
|
|
size_t size = DataTypeSize(this->data_type_);
|
|
|
|
|
size *= (format_ == schema::Format::Format_NC4HW4 || format_ == schema::Format::Format_NHWC4) ? ElementsC4Num()
|
|
|
|
|
: ElementsNum();
|
|
|
|
|
return size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Tensor::ElementsNum() const {
|
|
|
|
|
if (this->category_ == CONST_SCALAR) {
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t Tensor::ElementsC4Num() const {
|
|
|
|
|
if (this->category_ == CONST_SCALAR) {
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
int32_t result = 0;
|
|
|
|
|
if (this->shape_.size() == 4) {
|
|
|
|
|
result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
|
|
|
|
@ -217,6 +234,16 @@ int32_t Tensor::ElementsC4Num() const {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Tensor::DimensionSize(size_t index) const {
|
|
|
|
|
int dim_size = -1;
|
|
|
|
|
if (index < shape_.size()) {
|
|
|
|
|
dim_size = shape_[index];
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Dimension index is wrong: " << index;
|
|
|
|
|
}
|
|
|
|
|
return dim_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string Tensor::ToString() const {
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << "schema::Format: " << EnumNameFormat(this->format_);
|
|
|
|
@ -287,7 +314,7 @@ std::string Tensor::ToString() const {
|
|
|
|
|
|
|
|
|
|
int Tensor::MallocData(mindspore::lite::Allocator *allocator) {
|
|
|
|
|
if (nullptr != this->data_) {
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
if (allocator != nullptr) {
|
|
|
|
|
allocator_ = allocator;
|
|
|
|
@ -299,15 +326,15 @@ int Tensor::MallocData(mindspore::lite::Allocator *allocator) {
|
|
|
|
|
}
|
|
|
|
|
if (nullptr == this->data_) {
|
|
|
|
|
MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size();
|
|
|
|
|
return -1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Tensor::FreeData() {
|
|
|
|
|
if (nullptr == this->data_) {
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
if (nullptr == allocator_) {
|
|
|
|
|
free(this->data_);
|
|
|
|
@ -316,7 +343,7 @@ int Tensor::FreeData() {
|
|
|
|
|
allocator_->Free(this->data_);
|
|
|
|
|
this->data_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *Tensor::MutableData() {
|
|
|
|
@ -330,6 +357,12 @@ void *Tensor::MutableData() {
|
|
|
|
|
return this->data_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Tensor::IsConst() {
|
|
|
|
|
return (this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR) && this->data_ != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Tensor::IsScalar() { return this->category_ == CONST_SCALAR && this->data_ != nullptr; }
|
|
|
|
|
|
|
|
|
|
void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push_back(quant_arg); }
|
|
|
|
|
|
|
|
|
|
std::vector<QuantArg> Tensor::GetQuantParams() const { return this->quant_params_; }
|
|
|
|
|