!5006 [MS][LITE][Develop]Deconv Matmul 12x8

Merge pull request !5006 from ling/deconv
pull/5006/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 33a562de3d

@ -251,12 +251,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
}
}
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param) {
return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr);
}
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,

@ -24,7 +24,6 @@
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/fp32/conv_depthwise.h"
@ -52,10 +51,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
GEMM_FUNC_FP32 gemm_func);
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,

@ -33,18 +33,18 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
return;
}
int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
int in_plane8 = UP_ROUND(input_plane, C8NUM);
int in_plane12 = UP_ROUND(input_plane, C12NUM);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane8 * C8NUM;
int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;
int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
int dst_oh_stride = conv_param->output_w_ * C8NUM;
int dst_ow_stride = C8NUM;
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
@ -52,7 +52,7 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
for (int c = 0; c < oc8; c += 8) {
float *dst_ptr = tmp + c * output_plane;
const float *src_ptr = src + c * in_plane8 * kernel_plane;
const float *src_ptr = src + c * in_plane12 * kernel_plane;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
@ -101,41 +101,3 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
conv_param->is_relu6_);
return NNACL_OK;
}
/* Post-processing for the C4-packed deconvolution path.
 * Scatter-accumulates the per-kernel-tap GEMM results in `src` into a zeroed
 * c4-packed output plane `tmp_c4`, then calls PostConvFuncFp32C4 to produce
 * the final result in `dst` (presumably bias add + relu/relu6 + unpack --
 * PostConvFuncFp32C4's body is not visible here; confirm).
 *
 * Layout of `src` (established by the indexing below):
 *   [oc4 group][kh][kw][input_plane][C4NUM]
 *   - per-group block size: input_plane * kernel_plane * C4NUM
 *   - kh stride: input_plane * kernel_w * C4NUM, kw stride: input_plane * C4NUM
 *
 * output_channel: number of valid output channels handled here
 * input_plane  = input_h * input_w, kernel_plane = kernel_h * kernel_w,
 * output_plane = output_h * output_w
 * Returns NNACL_OK.
 */
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
                     int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
  int oc4 = UP_DIV(output_channel, C4NUM);
  for (int c = 0; c < oc4; c++) {
    float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
    const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM;
    /* accumulation plane must start at zero: taps are summed with += below */
    memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float));
    for (int ih = 0; ih < conv_param->input_h_; ih++) {
      for (int iw = 0; iw < conv_param->input_w_; iw++) {
        /* top-left output coordinate this input pixel projects to */
        int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
        int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
        /* clamp the kernel tap range so oh + kh*dilation_h (resp. ow + kw*dilation_w)
           stays inside [0, output_h) x [0, output_w) */
        int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
        int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
        int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
        int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
        for (int kh = kh_start; kh < kh_end; kh++) {
          for (int kw = kw_start; kw < kw_end; kw++) {
            int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM +
                            kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM;
            int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM +
                            kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM +
                            kw * conv_param->dilation_w_ * C4NUM;
            /* accumulate one C4 channel group for this tap */
            for (int i = 0; i < C4NUM; i++) {
              dst_ptr[dst_index + i] += src_ptr[src_index + i];
            }
          } /*kw*/
        }   /*kh*/
      }     /*iw*/
    }       /*ih*/
  }         /*oc4*/
  PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
                     conv_param->is_relu6_);
  return NNACL_OK;
}

@ -16,20 +16,19 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#define MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#include <string.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/common_func.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

@ -28,6 +28,18 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
return;
}
/* Re-pack a row-major [row x col] matrix so columns are grouped into
 * C12NUM-wide tiles: element (r, c) moves to
 * dst[(c / C12NUM) * C12NUM * row + r * C12NUM + (c % C12NUM)].
 * Lanes of the last tile beyond `col` are left untouched. */
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
  for (int c = 0; c < col; c++) {
    int tile = c / C12NUM;
    int lane = c % C12NUM;
    float *dst_tile = dst_ptr + tile * C12NUM * row;
    const float *src_col = src_ptr + c;
    for (int r = 0; r < row; r++) {
      dst_tile[r * C12NUM + lane] = src_col[r * col];
    }
  }
}
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -323,18 +335,18 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
return;
}
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
if (write_nhwc) {
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type) {
if (out_type == OutType_Nhwc) {
/* col8-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r8div = r / 8, r8mod = r % 8;
int r12div = r / 12, r12mod = r % 12;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
@ -345,18 +357,20 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
}
}
} else {
/* col8-major * row8-major => col8x8-major */
/* col8-major * row8-major => col12x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_8 = UP_ROUND(row, C8NUM);
for (int r = 0; r < row_8; r++) {
int row_12 = UP_ROUND(row, C12NUM);
for (int r = 0; r < row_12; r++) {
for (int c = 0; c < col_8; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = (out_type == OutType_C4) ? (c4div * C4NUM * row_12 + r * C4NUM + c4mod)
: (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
@ -369,45 +383,12 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
return;
}
/* Reference 12x8-tiled matmul fallback.
 * a: col12-major packed LHS (12-row tiles), b: row8-major packed RHS
 * (8-column tiles).  Only the write-NHWC path is implemented; the writeC4
 * flag is accepted for signature compatibility but ignored here, so the
 * function is a no-op when writeNhwc == 0 (matching the original). */
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, size_t stride, size_t writeNhwc, size_t writeC4) {
  if (writeNhwc == 0) {
    return;
  }
  /* col12-major * row8-major => plain col-major output */
  for (int r = 0; r < row; r++) {
    const float *a_tile = a + (r / 12) * deep * 12 + (r % 12);
    for (int c = 0; c < col; c++) {
      const float *b_tile = b + (c / 8) * deep * 8 + (c % 8);
      float acc = 0;
      for (int d = 0; d < deep; d++) {
        acc += a_tile[d * 12] * b_tile[d * 8];
      }
      if (bias != NULL) {
        acc += bias[c];
      }
      /* relu6 clamps the top; any non-None activation clamps below zero */
      if (act_type == ActType_Relu6 && acc > 6.0f) acc = 6.0f;
      if (act_type != ActType_No && acc < 0.0f) acc = 0.0f;
      dst[r * stride + c] = acc;
    }
  }
}
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif
}
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4) {
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, writeNhwc, writeC4);
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_C4));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, writeNhwc, writeC4);
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}

@ -26,11 +26,11 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
int stride, bool write_nhwc);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
@ -38,7 +38,7 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4);
int col, size_t stride, size_t write_nhwc, size_t write_c4);
#endif
#ifdef __cplusplus
}

@ -1,204 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/strassen_matmul.h"
/* Decide whether another Strassen recursion level pays off.
 * Parameter order is (row, col, deep, max_recursion, cur_recursion).
 * Requires: recursion budget not exhausted, all three dimensions even
 * (dimensions are positive counts), and the estimated multiply-add work
 * saved by 7 half-size products must exceed the cost of the extra 4-packed
 * add/sub passes. */
bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) {
  if (cur_recursion >= max_recursion) {
    return false;
  }
  /* all dims must split evenly in half */
  if (((row | col | deep) & 1) != 0) {
    return false;
  }
  const int half_row = row / 2;
  const int half_col = col / 2;
  const int half_deep = deep / 2;
  /* full-size multiply cost minus (7 half multiplies + add/sub overhead),
     same integer expression as the original estimate */
  const int saved = row * col * 4 * deep * 4 * 2 + row * col * 4 -
                    7 * (half_row * half_col * 4 * half_deep * 4 * 2 - half_row * half_col * 4) -
                    4 * (half_row * half_deep * 4 * 3) - 4 * (half_deep * 4 * half_col * 4 * 3) -
                    7 * (half_row * half_col * 4 * 3);
  return saved > 0;
}
/* Tiled GEMM on 4-packed operands.
 * a: packed in 4-row tiles (the last tile may hold only row%4 rows),
 * b: 4-column tiles with per-tile stride b_stride,
 * dst: 4-column tiles with per-tile stride c_stride.
 * Accumulation order matches the straightforward d-loop, so results are
 * bit-identical to the reference implementation. */
void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
                    int c_stride) {
  const int full_rows = (row / 4) * 4; /* rows living in complete 4-row tiles */
  const int tail_rows = row % 4;       /* height of the final partial tile */
  for (int r = 0; r < row; r++) {
    const int tile_r = r / 4;
    const int lane_r = r % 4;
    /* the partial tail tile is packed with only tail_rows rows */
    const int tile_height = (r < full_rows) ? 4 : tail_rows;
    const float *a_tile = a_ptr + tile_r * 4 * deep * 4;
    for (int c = 0; c < col * 4; c++) {
      const int tile_c = c / 4;
      const int lane_c = c % 4;
      const float *b_col = b_ptr + tile_c * b_stride + lane_c;
      float acc = 0;
      for (int d = 0; d < deep * 4; d++) {
        acc = acc + a_tile[(d / 4) * tile_height * 4 + lane_r * 4 + (d % 4)] * b_col[d * 4];
      }
      dst_ptr[tile_c * c_stride + r * 4 + lane_c] = acc;
    }
  }
}
/* Run the tiled GEMM over all rows: complete 4-row tiles in one pass, then a
 * second pass for the partial tail tile (row % 4 rows), whose packed data
 * starts right after the full tiles in `a` and whose output starts after
 * row4div complete 4x4 output tiles. */
void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
                int c_stride) {
  const int row4div = row / 4;
  const int row4mod = row % 4;
  if (row4div > 0) {
    GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride);
  }
  if (row4mod != 0) {
    const float *a_tail = a_ptr + row4div * deep * 16;
    float *dst_tail = dst_ptr + row4div * 16;
    GemmMatMulComm(a_tail, b_ptr, dst_tail, row4mod, col, deep, b_stride, c_stride);
  }
}
/* One level of Strassen's matrix multiplication (Winograd-style schedule:
 * 7 recursive half-size products P1..P7 combined through the S/T/U
 * intermediate sums).  All operands are packed in 4-float units, which is
 * why the FP32_STRASSEN_*UINT factors appear in every stride and why the
 * B-side add/sub passes use `deep2 * 4` as their row count.
 * x_ptr/y_ptr are per-level scratch for the S (LHS) and T (RHS)
 * combinations; tmp_a_ptr is threaded through to the leaf CommonMatMul.
 * Returns NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC on scratch allocation
 * failure, NNACL_OK otherwise. */
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
                    int max_recursion, int cur_recursion, float *tmp_a_ptr) {
  size_t row2 = matmul_param->row_ / 2;
  size_t deep2 = matmul_param->deep_ / 2;
  size_t col2 = matmul_param->col_ / 2;
  size_t a_stride = matmul_param->a_stride_;
  size_t b_stride = matmul_param->b_stride_;
  size_t c_stride = matmul_param->c_stride_;
  /* half-size problem description shared by all 7 recursive products */
  StrassenMatMulParameter rec_matmul;
  rec_matmul.row_ = row2;
  rec_matmul.deep_ = deep2;
  rec_matmul.col_ = col2;
  /* x_ptr doubles as S-operand scratch and (later) P1 result storage,
     hence the MSMAX(deep2, col2) sizing */
  float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float)));
  if (x_ptr == NULL) {
    return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
  }
  float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
  if (y_ptr == NULL) {
    free(x_ptr);
    return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
  }
  size_t x_stride = row2 * FP32_STRASSEN_UINT;
  size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT;
  /* 2x2 block views of A, B and C in their packed layouts */
  const float *a11 = a_ptr;
  const float *a12 = a_ptr + deep2 * a_stride;
  const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT;
  const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT;
  const float *b11 = b_ptr;
  const float *b12 = b_ptr + col2 * b_stride;
  const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT;
  const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT;
  float *c11 = c_ptr;
  float *c12 = c_ptr + col2 * c_stride;
  float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT;
  float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT;
  /* S3 = A11 - A21 */
  MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
  /* T3 = B22 - B12 */
  MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
  /* P7 = S3T3 */
  /* NOTE(review): return codes of the recursive StrassenMatmul calls are
     ignored throughout; an inner malloc failure is not propagated. */
  rec_matmul.a_stride_ = x_stride;
  rec_matmul.b_stride_ = y_stride;
  rec_matmul.c_stride_ = c_stride;
  StrassenMatmul(x_ptr, y_ptr, c21, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* S1 = A21 + A22 */
  MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
  /* T1 = B12 - B11 */
  MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
  /* P5 = S1T1 */
  StrassenMatmul(x_ptr, y_ptr, c22, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* S2 = S1 - A11 */
  MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2);
  /* T2 = B22 - T1 */
  MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2);
  /* P6 = S2T2 */
  StrassenMatmul(x_ptr, y_ptr, c12, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* S4 = A12 - S2 */
  MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2);
  /* P3 = S4B22 */
  rec_matmul.b_stride_ = b_stride;
  StrassenMatmul(x_ptr, b22, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* P1 = A11B11 (written to x_ptr scratch, consumed by MatrixMultiAdd) */
  rec_matmul.a_stride_ = a_stride;
  rec_matmul.c_stride_ = row2 * FP32_STRASSEN_UINT;
  StrassenMatmul(a11, b11, x_ptr, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* U2 = P1 + P6
     U3 = U2 + P7
     U4 = U2 + P5
     U7 = U3 + P5
     U5 = U4 + P3 */
  MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride);
  /* T4 = T2 - B21 */
  MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2);
  /* P4 = A22T4 */
  rec_matmul.b_stride_ = y_stride;
  rec_matmul.c_stride_ = c_stride;
  StrassenMatmul(a22, y_ptr, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* U6 = U3 - P4 */
  MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2);
  /* P2 = A12B21 */
  rec_matmul.b_stride_ = b_stride;
  StrassenMatmul(a12, b21, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
  /* U1 = P1 + P2 */
  MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2);
  free(x_ptr);
  free(y_ptr);
  return NNACL_OK;
}
/* Non-recursive leaf of the Strassen dispatch: re-pack the LHS into the
 * caller-provided tmp_a_ptr scratch, then run the plain tiled GEMM.
 * Always returns NNACL_OK. */
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
                 float *tmp_a_ptr) {
  const int rows = matmul_param->row_;
  const int cols = matmul_param->col_;
  const int depth = matmul_param->deep_;
  MatrixPack(a_ptr, tmp_a_ptr, rows, depth, matmul_param->a_stride_);
  GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, rows, cols, depth, matmul_param->b_stride_, matmul_param->c_stride_);
  return NNACL_OK;
}
/* Entry point of the fp32 Strassen matmul: recurse with Strassen's scheme
 * while CheckRecursion says it is profitable and within the recursion
 * budget, otherwise fall back to the packed CommonMatMul.
 * Returns the recursion's error code (malloc failure is possible) or
 * NNACL_OK from the common path.
 *
 * Fix: CheckRecursion's signature is (row, col, deep, max_recursion,
 * cur_recursion), but the previous call passed cur_recursion and
 * max_recursion swapped.  That made the gate test `max >= cur`, which is
 * always true here, so the Strassen path was never taken.  Arguments are
 * now passed in declaration order. */
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
                   int max_recursion, int cur_recursion, float *tmp_a_ptr) {
  if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, max_recursion, cur_recursion)) {
    return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr);
  }
  return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr);
}

@ -1,45 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#include <memory.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/strassen_matmul.h"
#include "nnacl/fp32/common_func.h"
#define FP32_STRASSEN_UINT C4NUM
#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM)
#define FP32_STRASSEN_MAX_RECURSION 5
#ifdef __cplusplus
extern "C" {
#endif
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int, float *tmp_a_ptr);
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param,
float *tmp_a_ptr);
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_

@ -31,6 +31,8 @@ typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_C4 = 2 } OutType;
typedef struct MatMulParameter {
OpParameter op_parameter_;
int row_;

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#include "nnacl/op_base.h"
/* hw*inc4 X inc4*oc4 */
/* Shape/stride description for the fp32 Strassen matmul:
 * computes a (h*w) x inc4 times inc4 x oc4 product where the channel
 * dimensions are packed in 4-float units, so row_/col_/deep_ count packed
 * units and the strides are element strides between packed columns. */
typedef struct StrassenMatMulParameter {
  OpParameter op_parameter;
  int row_;      /* LHS rows: h * w */
  int col_;      /* RHS columns in 4-packed units: oc4 / 4 */
  int deep_;     /* shared (reduce) dimension in 4-packed units: inc4 / 4 */
  int a_stride_; /* stride between packed LHS columns: h * w * 4 */
  int b_stride_; /* stride between packed RHS column tiles: inc4 * 4 */
  int c_stride_; /* stride between packed output column tiles: h * w * 4 */
} StrassenMatMulParameter;
#endif // MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_

@ -39,6 +39,10 @@ void Convolution1x1CPUKernel::FreeTmpBuffer() {
free(pack_input_);
pack_input_ = nullptr;
}
if (pre_trans_input_ && input_ptr_ != nullptr) {
free(input_ptr_);
input_ptr_ = nullptr;
}
return;
}
@ -106,6 +110,16 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
return RET_MEMORY_FAILED;
}
memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float));
}
return RET_OK;
}
@ -140,13 +154,10 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, 1, 0);
output_ptr_ + task_id * thread_stride_, reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id,
matmul_param_->act_type_, matmul_param_->deep_, matmul_param_->row_, cur_oc, matmul_param_->col_,
OutType_Nhwc);
return RET_OK;
}
@ -169,15 +180,6 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->Data());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->Data());
if (pre_trans_input_) {
input_ptr_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
}
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
@ -189,10 +191,6 @@ int Convolution1x1CPUKernel::Run() {
}
}
if (pre_trans_input_) {
ctx_->allocator->Free(input_ptr_);
input_ptr_ = nullptr;
}
return RET_OK;
}
} // namespace mindspore::kernel

@ -95,13 +95,13 @@ int DeConvolutionCPUKernel::InitParam() {
matmul_param_->row_ = input_plane_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;
@ -126,14 +126,14 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_;
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer,
nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, false);
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
DeConvPostFp32C8x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
return RET_OK;
}
@ -165,7 +165,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
}
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_NULL_PTR;
@ -192,7 +192,7 @@ int DeConvolutionCPUKernel::Run() {
input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;
RowMajor2Col8Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
RowMajor2Col12Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_count_);
if (error_code != RET_OK) {

@ -27,18 +27,14 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() {
}
void FullconnectionCPUKernel::FreeBuf() {
if (a_c8_ptr_ != nullptr) {
free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_c12_ptr_ != nullptr) {
free(a_c12_ptr_);
a_c12_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
@ -51,8 +47,8 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->col_ = (in_tensors_[1]->shape())[0];
fc_param_->deep_ = (in_tensors_[1]->shape())[1];
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
@ -63,11 +59,11 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->Data(), fc_param_->col_ * sizeof(float));
}
a_c8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(float)));
if (a_c8_ptr_ == nullptr) {
a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(float));
memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float));
b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@ -76,16 +72,9 @@ int FullconnectionCPUKernel::ReSize() {
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float));
c_r8x8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)));
if (c_r8x8_ptr_ == nullptr) {
FreeBuf();
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float));
fc_param_->a_const_ = false;
fc_param_->b_const_ = false;
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
return RET_OK;
}
@ -105,7 +94,7 @@ void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
return;
}
fc_param_->a_const_ = true;
RowMajor2Col8Major(src_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
return;
}
@ -132,15 +121,14 @@ int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
int FullconnectionCPUKernel::DoMatmul(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
if (cur_oc <= 0) {
return RET_OK;
}
MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
cur_oc * 8, 0, false);
MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r_ptr + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM,
fc_param_->act_type_, fc_param_->deep_, fc_param_->row_, cur_oc, fc_param_->col_, OutType_Nhwc);
return RET_OK;
}
@ -152,14 +140,13 @@ int FullconnectionCPUKernel::Run() {
}
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
c_r_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
InitMatrixA(a_ptr, a_c8_ptr_);
InitMatrixA(a_ptr, a_c12_ptr_);
InitMatrixB(b_ptr, b_r8_ptr_);
LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_, fc_param_->col_);
return RET_OK;
}
} // namespace mindspore::kernel

@ -47,9 +47,9 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
void InitMatrixB(float *src_ptr, float *dst_ptr);
private:
float *a_c8_ptr_ = nullptr;
float *a_c12_ptr_ = nullptr;
float *b_r8_ptr_ = nullptr;
float *c_r8x8_ptr_ = nullptr;
float *c_r_ptr = nullptr;
float *bias_ptr_ = nullptr;
};
} // namespace mindspore::kernel

@ -28,18 +28,14 @@ namespace mindspore::kernel {
MatmulCPUKernel::~MatmulCPUKernel() { FreeTmpBuffer(); }
void MatmulCPUKernel::FreeTmpBuffer() {
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_c12_ptr_ != nullptr) {
ctx_->allocator->Free(a_c12_ptr_);
a_c12_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
@ -66,45 +62,37 @@ int MatmulCPUKernel::ReSize() {
params_->row_ = c_shape[c_shape.size() - 2];
params_->col_ = c_shape[c_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
a_c8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(float)));
if (a_c8_ptr_ == nullptr) {
a_c12_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_12_ * params_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(float));
memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
b_r8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(float));
c_r8x8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(float)));
if (c_r8x8_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float));
params_->a_const_ = false;
params_->b_const_ = false;
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
if (in_tensors_.size() == 3) {
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
memcpy(bias_ptr_, in_tensors_[2]->Data(), params_->col_ * sizeof(float));
} else {
bias_ptr_ = nullptr;
}
return RET_OK;
@ -120,9 +108,9 @@ void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
params_->a_const_ = true;
if (params_->a_transpose_) {
RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
RowMajor2Row12Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
} else {
RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_);
RowMajor2Col12Major(src_ptr, dst_ptr, params_->row_, params_->deep_);
}
return;
}
@ -152,18 +140,13 @@ int MatmulCPUKernel::Init() {
}
int MatmulCPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
if (bias_ptr_) {
auto cur_bias = bias_ptr_ + task_id * thread_stride_ * C8NUM;
MatMul(a_c8_ptr_, cur_b, cur_c, cur_bias, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
} else {
MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
}
MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_,
c_r_ptr_ + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM, ActType_No,
params_->deep_, params_->row_, cur_oc, params_->col_, OutType_Nhwc);
return RET_OK;
}
@ -192,13 +175,12 @@ int MatmulCPUKernel::Run() {
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
auto cur_c_ptr = c_ptr + i * c_stride;
c_r_ptr_ = c_ptr + i * c_stride;
InitMatrixA(cur_a_ptr, a_c8_ptr_);
InitMatrixA(cur_a_ptr, a_c12_ptr_);
InitMatrixB(cur_b_ptr, b_r8_ptr_);
LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
}
return RET_OK;
}

@ -41,9 +41,9 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
void FreeTmpBuffer();
private:
float *a_c8_ptr_ = nullptr;
float *a_c12_ptr_ = nullptr;
float *b_r8_ptr_ = nullptr;
float *c_r8x8_ptr_ = nullptr;
float *c_r_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
};
} // namespace mindspore::kernel

@ -19,9 +19,8 @@
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/strassen_matmul.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
namespace mindspore {
using mindspore::lite::tensor::Tensor;

@ -548,14 +548,14 @@ TEST_F(TestDeConvolutionFp32, DeConvTest2) {
float *correct;
int total_size = DeConvTestInit2(&inputs_, &outputs_, deconv_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->thread_num_ = 4;
ctx->thread_num_ = 1;
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx, nullptr);
deconv->Init();
deconv->Run();
EXPECT_EQ(0, lite::CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size));
delete deconv_param;
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@ -635,7 +635,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest3) {
deconv->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@ -723,7 +722,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest4) {
uint64_t time_avg = cost / loop_count;
printf("deconv fp32 average time : %f ms\n", time_avg / 1000.0f);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;

Loading…
Cancel
Save