!6558 [MSLITE] Fix the bug of reading FC's tensor shape

Merge pull request !6558 from zhanyuan/dev
pull/6558/MERGE
Committed by mindspore-ci-bot on Gitee
commit d90377235c

@@ -91,7 +91,7 @@ void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const floa
   return;
 }
 
-void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
   size_t row_up_16 = UP_ROUND(row, C16NUM);
   size_t row16 = row / C16NUM * C16NUM;
   size_t col8 = col / C8NUM * C8NUM;
@@ -245,52 +245,58 @@ void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row,
   return;
 }
 
-void Fp32RowMajor2Fp16Col16Major(float *src, float16_t *dst, size_t row, size_t col) {
+void RowMajor2Col16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
       int r_div16 = r / 16;
       int r_mod16 = r % 16;
-      dst[r_div16 * 16 * col + c * 16 + r_mod16] = (float16_t)(src[r * col + c]);
+      if (is_fp32_src) {
+        dst[r_div16 * 16 * col + c * 16 + r_mod16] = (float16_t)(((float *)src)[r * col + c]);
+      } else {
+        dst[r_div16 * 16 * col + c * 16 + r_mod16] = ((float16_t *)src)[r * col + c];
+      }
     }
   }
 }
 
-void Fp32RowMajor2Fp16Row16Major(float *src, float16_t *dst, size_t row, size_t col) {
-  for (int r = 0; r < row; r++) {
-    for (int c = 0; c < col; c++) {
-      int c_div16 = c / 16;
-      int c_mod16 = c % 16;
-      dst[c_div16 * 16 * row + r * 16 + c_mod16] = (float16_t)(src[r * col + c]);
-    }
-  }
-}
-
-void Fp16RowMajor2Fp16Row16Major(float16_t *src, float16_t *dst, size_t row, size_t col) {
+void RowMajor2Row16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
       int c_div16 = c / 16;
       int c_mod16 = c % 16;
-      dst[c_div16 * 16 * row + r * 16 + c_mod16] = src[r * col + c];
+      if (is_fp32_src) {
+        dst[c_div16 * 16 * row + r * 16 + c_mod16] = (float16_t)(((float *)src)[r * col + c]);
+      } else {
+        dst[c_div16 * 16 * row + r * 16 + c_mod16] = ((float16_t *)src)[r * col + c];
+      }
     }
   }
 }
 
-void Fp32RowMajor2Fp16Row8Major(float *src, float16_t *dst, size_t row, size_t col) {
+void RowMajor2Row8MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
       int c_div8 = c / 8;
       int c_mod8 = c % 8;
-      dst[c_div8 * 8 * row + r * 8 + c_mod8] = (float16_t)src[r * col + c];
+      if (is_fp32_src) {
+        dst[c_div8 * 8 * row + r * 8 + c_mod8] = (float16_t)(((float *)src)[r * col + c]);
+      } else {
+        dst[c_div8 * 8 * row + r * 8 + c_mod8] = ((float16_t *)src)[r * col + c];
+      }
     }
   }
 }
 
-void Fp32RowMajor2Fp16Col8Major(float *src, float16_t *dst, size_t row, size_t col) {
+void RowMajor2Col8MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
       int r_div8 = r / 8;
       int r_mod8 = r % 8;
-      dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)src[r * col + c];
+      if (is_fp32_src) {
+        dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((float *)src)[r * col + c]);
+      } else {
+        dst[r_div8 * 8 * col + c * 8 + r_mod8] = ((float16_t *)src)[r * col + c];
+      }
     }
   }
 }
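To make the new pack helpers above concrete, here is a standalone sketch of the col16-major indexing they implement. Everything in it (the PackCol16 mirror, the 2x3 example matrix, the main function) is illustrative only and not part of the patch; it assumes a toolchain where a 16-bit float type is available (e.g. __fp16 on aarch64), matching the float16_t used in the nnacl fp16 code.

// Illustrative only: mirrors the dst[r_div16 * 16 * col + c * 16 + r_mod16]
// indexing used by RowMajor2Col16MajorFp16 for a small row-major matrix.
#include <cstdio>

typedef __fp16 float16_t;  // assumption: aarch64-style toolchain with __fp16

static void PackCol16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r_div16 = r / 16;  // which 16-row tile this row falls into
      int r_mod16 = r % 16;  // position inside that tile
      float16_t v = is_fp32_src ? (float16_t)(((const float *)src)[r * col + c])
                                : ((const float16_t *)src)[r * col + c];
      dst[r_div16 * 16 * col + c * 16 + r_mod16] = v;
    }
  }
}

int main() {
  const float a[2 * 3] = {1, 2, 3, 4, 5, 6};  // 2x3 row-major fp32 source
  float16_t packed[16 * 3] = {0};             // one 16-row tile, 3 columns
  PackCol16(a, packed, 2, 3, true);
  // Element (r = 1, c = 2) lands at c * 16 + r = 33.
  std::printf("%f\n", (float)packed[33]);  // prints 6.000000
  return 0;
}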

@@ -34,20 +34,18 @@ void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const floa
 void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16);
-void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
+void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
 void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                       size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
-void Fp32RowMajor2Fp16Col16Major(float *src, float16_t *dst, size_t row, size_t col);
-void Fp32RowMajor2Fp16Row16Major(float *src, float16_t *dst, size_t row, size_t col);
-void Fp16RowMajor2Fp16Row16Major(float16_t *src, float16_t *dst, size_t row, size_t col);
-void Fp32RowMajor2Fp16Row8Major(float *src, float16_t *dst, size_t row, size_t col);
-void Fp32RowMajor2Fp16Col8Major(float *src, float16_t *dst, size_t row, size_t col);
+void RowMajor2Col16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);
+void RowMajor2Row16MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);
+void RowMajor2Row8MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);
+void RowMajor2Col8MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src);
 #ifdef __cplusplus
 }

@@ -161,7 +161,7 @@ void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *s
     input_ptr_ = src_input;
   }
-  RowMajor2Col16MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+  RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
   return;
 }

@@ -186,7 +186,7 @@ int DeConvolutionFp16CPUKernel::Run() {
   }
   for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
-    RowMajor2Col16MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
+    RowMajor2Col16MajorFp16Opt(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
     error_code = ParallelLaunch(this->context_->thread_pool_, DeConvFp16Run, this, thread_count_);
     if (error_code != RET_OK) {

@@ -52,8 +52,10 @@ void FullconnectionFP16CPUKernel::FreeTmpBuffer() {
 int FullconnectionFP16CPUKernel::ReSize() {
   FreeTmpBuffer();
-  fc_param_->row_ = (in_tensors_[0]->shape())[0];
-  fc_param_->col_ = (in_tensors_[1]->shape())[0];
+  int row = 1;
+  for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) row *= (out_tensors_[0]->shape())[i];
+  fc_param_->row_ = row;
+  fc_param_->col_ = out_tensors_[0]->shape().back();
   fc_param_->deep_ = (in_tensors_[1]->shape())[1];
   fc_param_->row_16_ = UP_ROUND(fc_param_->row_, C16NUM);
   fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
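The hunk above is the core of the fix: row_ was previously read from the first dimension of the input tensor and col_ from the first dimension of the weight tensor, which is only correct for a plain 2-D FullConnection. The new code derives row_ as the product of all output dimensions except the last, and col_ from the last output dimension. A minimal worked example with hypothetical shapes (input [2, 4, 8], weight [10, 8], output [2, 4, 10]; none of these values come from the patch):

// Hypothetical shapes, for illustration only (not taken from the patch):
//   input [2, 4, 8], weight [10, 8], output [2, 4, 10]
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> out_shape = {2, 4, 10};
  int row = 1;
  for (std::size_t i = 0; i < out_shape.size() - 1; ++i) row *= out_shape[i];
  int col = out_shape.back();
  std::printf("row = %d, col = %d\n", row, col);  // row = 8, col = 10
  // The old code would have set row_ = input shape[0] = 2, so only the first
  // 2 of the 8 logical rows would have been packed and multiplied.
  return 0;
}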
@@ -76,7 +78,7 @@ int FullconnectionFP16CPUKernel::ReSize() {
   }
   memset(b_pack_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float16_t));
-  InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->MutableData()), b_pack_ptr_);
+  InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
   if (in_tensors_.size() == 3) {
     bias_ptr_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * sizeof(float16_t)));
     if (bias_ptr_ == nullptr) {
@@ -84,7 +86,7 @@ int FullconnectionFP16CPUKernel::ReSize() {
       return RET_MEMORY_FAILED;
     }
     memset(bias_ptr_, 0, fc_param_->col_8_ * sizeof(float16_t));
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->MutableData()), bias_ptr_, fc_param_->col_);
+    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->data_c()), bias_ptr_, fc_param_->col_);
   }
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32) {
@@ -95,15 +97,15 @@ int FullconnectionFP16CPUKernel::ReSize() {
 }
 
 void FullconnectionFP16CPUKernel::InitMatrixA(float *a_ptr, float16_t *a_pack_ptr) {
-  Fp32RowMajor2Fp16Col16Major(a_ptr, a_pack_ptr, fc_param_->row_, fc_param_->deep_);
+  RowMajor2Col16MajorFp16(reinterpret_cast<void *>(a_ptr), a_pack_ptr, fc_param_->row_, fc_param_->deep_, true);
 }
 
 void FullconnectionFP16CPUKernel::InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr) {
-  RowMajor2Col16MajorFp16(a_ptr, a_pack_ptr, fc_param_->row_, fc_param_->deep_);
+  RowMajor2Col16MajorFp16(reinterpret_cast<void *>(a_ptr), a_pack_ptr, fc_param_->row_, fc_param_->deep_, false);
 }
 
 void FullconnectionFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_ptr) {
-  Fp32RowMajor2Fp16Col8Major(b_ptr, b_pack_ptr, fc_param_->col_, fc_param_->deep_);
+  RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, true);
 }
 
 int FullconnectionFP16CPUKernel::Init() {
@@ -147,17 +149,17 @@ int FullconnectionFP16CPUKernel::Run() {
   if (out_tensor->data_type() == kNumberTypeFloat32) {
     output_ptr_ = output_fp16_;
   } else {
-    output_ptr_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
+    output_ptr_ = reinterpret_cast<float16_t *>(out_tensor->data_c());
   }
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
-    InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+    InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()), a_pack_ptr_);
   } else {
-    InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+    InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_);
   }
   ParallelLaunch(this->context_->thread_pool_, FcFP16Run, this, thread_count_);
   if (out_tensor->data_type() == kNumberTypeFloat32) {
     auto size = out_tensor->ElementsNum();
-    auto out_tensor_data = reinterpret_cast<float *>(out_tensor->MutableData());
+    auto out_tensor_data = reinterpret_cast<float *>(out_tensor->data_c());
     Float16ToFloat32(output_fp16_, out_tensor_data, size);
   }
   return RET_OK;

@@ -91,17 +91,21 @@ int MatmulFP16CPUKernel::ReSize() {
   }
   memset(b_pack_ptr_, 0, params_->batch * params_->col_8_ * params_->deep_ * sizeof(float16_t));
 
-  params_->a_const_ = (in_tensors_[0]->MutableData() != nullptr);
-  params_->b_const_ = (in_tensors_[1]->MutableData() != nullptr);
+  params_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
+  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
   if (params_->a_const_ == true) {
     if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
-      InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+      InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()), a_pack_ptr_);
     } else {
-      InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+      InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_);
     }
   }
   if (params_->b_const_ == true) {
-    InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->MutableData()), b_pack_ptr_);
+    if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
+      InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    } else {
+      InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    }
   }
 
   if (in_tensors_.size() == 3) {
@@ -111,7 +115,7 @@ int MatmulFP16CPUKernel::ReSize() {
       return RET_MEMORY_FAILED;
     }
     memset(bias_ptr_, 0, params_->col_8_ * sizeof(float16_t));
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->MutableData()), bias_ptr_, params_->col_);
+    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->data_c()), bias_ptr_, params_->col_);
   }
 
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32) {
@@ -130,9 +134,9 @@ void MatmulFP16CPUKernel::InitMatrixA(float *a_ptr, float16_t *a_pack_ptr) {
     float *src = a_ptr + i * params_->deep_ * params_->row_;
     float16_t *dst = a_pack_ptr + i * params_->deep_ * params_->row_16_;
     if (params_->a_transpose_) {
-      Fp32RowMajor2Fp16Row16Major(src, dst, params_->deep_, params_->row_);
+      RowMajor2Row16MajorFp16(reinterpret_cast<void *>(src), dst, params_->deep_, params_->row_, true);
     } else {
-      Fp32RowMajor2Fp16Col16Major(src, dst, params_->row_, params_->deep_);
+      RowMajor2Col16MajorFp16(reinterpret_cast<void *>(src), dst, params_->row_, params_->deep_, true);
     }
   }
 }
@@ -142,9 +146,9 @@ void MatmulFP16CPUKernel::InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr) {
     float16_t *src = a_ptr + i * params_->deep_ * params_->row_;
     float16_t *dst = a_pack_ptr + i * params_->deep_ * params_->row_16_;
     if (params_->a_transpose_) {
-      Fp16RowMajor2Fp16Row16Major(src, dst, params_->deep_, params_->row_);
+      RowMajor2Row16MajorFp16(reinterpret_cast<void *>(src), dst, params_->deep_, params_->row_, false);
     } else {
-      RowMajor2Col16MajorFp16(src, dst, params_->row_, params_->deep_);
+      RowMajor2Col16MajorFp16(reinterpret_cast<void *>(src), dst, params_->row_, params_->deep_, false);
     }
   }
 }
@@ -154,9 +158,21 @@ void MatmulFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_ptr) {
     float *src = b_ptr + i * params_->deep_ * params_->col_;
     float16_t *dst = b_pack_ptr + i * params_->deep_ * params_->col_8_;
     if (params_->b_transpose_) {
-      Fp32RowMajor2Fp16Col8Major(src, dst, params_->col_, params_->deep_);
+      RowMajor2Col8MajorFp16(reinterpret_cast<void *>(src), dst, params_->col_, params_->deep_, true);
+    } else {
+      RowMajor2Row8MajorFp16(reinterpret_cast<void *>(src), dst, params_->deep_, params_->col_, true);
+    }
+  }
+}
+
+void MatmulFP16CPUKernel::InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr) {
+  for (int i = 0; i < params_->batch; i++) {
+    float16_t *src = b_ptr + i * params_->deep_ * params_->col_;
+    float16_t *dst = b_pack_ptr + i * params_->deep_ * params_->col_8_;
+    if (params_->b_transpose_) {
+      RowMajor2Col8MajorFp16(reinterpret_cast<void *>(src), dst, params_->col_, params_->deep_, false);
     } else {
-      Fp32RowMajor2Fp16Row8Major(src, dst, params_->deep_, params_->col_);
+      RowMajor2Row8MajorFp16(reinterpret_cast<void *>(src), dst, params_->deep_, params_->col_, false);
     }
   }
 }
@@ -198,23 +214,26 @@ int MatmulFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  auto b = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
   auto out_tensor = out_tensors_[0];
   float16_t *c_ptr = nullptr;
   if (out_tensor->data_type() == kNumberTypeFloat32) {
     c_ptr = output_ptr_;
   } else {
-    c_ptr = reinterpret_cast<float16_t *>(out_tensor->MutableData());
+    c_ptr = reinterpret_cast<float16_t *>(out_tensor->data_c());
   }
   if (params_->a_const_ == false) {
     if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
-      InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+      InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()), a_pack_ptr_);
     } else {
-      InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->MutableData()), a_pack_ptr_);
+      InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_);
     }
   }
   if (params_->b_const_ == false) {
-    InitMatrixB(b, b_pack_ptr_);
+    if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
+      InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    } else {
+      InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    }
   }
   for (int i = 0; i < params_->batch; ++i) {
     current_a_ = a_pack_ptr_ + i * params_->row_16_ * params_->deep_;
@@ -224,7 +243,7 @@ int MatmulFP16CPUKernel::Run() {
   }
   if (out_tensor->data_type() == kNumberTypeFloat32) {
     auto size = out_tensor->ElementsNum();
-    auto out_tensor_data = reinterpret_cast<float *>(out_tensor->MutableData());
+    auto out_tensor_data = reinterpret_cast<float *>(out_tensor->data_c());
     Float16ToFloat32(output_ptr_, out_tensor_data, size);
   }
   return RET_OK;
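Besides the new fp16 branch for matrix B, the hunks in this kernel (and in the FullConnection kernel above) consistently replace MutableData() with data_c() when reading tensor buffers and when probing a_const_/b_const_. The mock below is only a toy model of why that distinction matters for a "does this tensor already hold data?" check; it is not the lite::Tensor implementation, and the real API may differ in details.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Toy stand-in for a tensor with two accessors: one that only reports the
// currently held buffer, and one that allocates on demand.
class MockTensor {
 public:
  ~MockTensor() { std::free(data_); }
  void *data_c() const { return data_; }  // may be nullptr: nothing set yet
  void *MutableData() {                   // never nullptr: allocates if empty
    if (data_ == nullptr) data_ = std::malloc(size_);
    return data_;
  }

 private:
  void *data_ = nullptr;
  std::size_t size_ = 64;
};

int main() {
  MockTensor weight;
  // A probe such as params_->b_const_ = (tensor->data_c() != nullptr) stays
  // false for an unset tensor, while MutableData() would always report true.
  std::printf("data_c():      %s\n", weight.data_c() != nullptr ? "has data" : "empty");
  std::printf("MutableData(): %s\n", weight.MutableData() != nullptr ? "has data" : "empty");
  return 0;
}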

@@ -42,6 +42,7 @@ class MatmulFP16CPUKernel : public MatmulBaseCPUKernel {
   void InitMatrixA(float *a_ptr, float16_t *a_pack_ptr);
   void InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr);
   void InitMatrixB(float *b_ptr, float16_t *b_pack_ptr);
+  void InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr);
   void FreeTmpBuffer();
 
  private:

@@ -43,8 +43,10 @@ void FullconnectionCPUKernel::FreeBuf() {
 int FullconnectionCPUKernel::ReSize() {
   FreeBuf();
-  fc_param_->row_ = (in_tensors_[0]->shape())[0];
-  fc_param_->col_ = (in_tensors_[1]->shape())[0];
+  int row = 1;
+  for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) row *= (out_tensors_[0]->shape())[i];
+  fc_param_->row_ = row;
+  fc_param_->col_ = out_tensors_[0]->shape().back();
   fc_param_->deep_ = (in_tensors_[1]->shape())[1];
   fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);

@@ -33,8 +33,10 @@ int FullconnectionInt8CPUKernel::Init() {
 int FullconnectionInt8CPUKernel::ReSize() {
   FreeTmpBuffer();
-  fc_param_->row_ = (in_tensors_[0]->shape())[0];
-  fc_param_->col_ = (in_tensors_[1]->shape())[0];
+  int row = 1;
+  for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) row *= (out_tensors_[0]->shape())[i];
+  fc_param_->row_ = row;
+  fc_param_->col_ = out_tensors_[0]->shape().back();
   fc_param_->deep_ = (in_tensors_[1]->shape())[1];
   fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
   fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
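All three FullConnection kernels (fp32, fp16, int8) now share the same row/col derivation and then round those dimensions up to the tile sizes their packed buffers use. A quick worked example of that round-up, assuming UP_ROUND has the usual "round x up to a multiple of align" meaning (the helper below is a local stand-in, not the nnacl macro):

#include <cstdio>

// Local stand-in for UP_ROUND: round x up to the nearest multiple of align.
static int UpRound(int x, int align) { return (x + align - 1) / align * align; }

int main() {
  // Using the hypothetical FC example from above: row = 8, col = 10.
  std::printf("row_8_  = %d\n", UpRound(8, 8));    // 8: already a multiple of 8
  std::printf("col_8_  = %d\n", UpRound(10, 8));   // 16: padded to the next 8-wide tile
  std::printf("row_16_ = %d\n", UpRound(8, 16));   // 16: the fp16 kernel packs 16-row tiles
  std::printf("row_12_ = %d\n", UpRound(8, 12));   // 12: the fp32 kernel packs 12-row tiles
  return 0;
}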
