diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc
index 619fcade8c..53ea430a9d 100644
--- a/mindspore/lite/src/ops/matmul.cc
+++ b/mindspore/lite/src/ops/matmul.cc
@@ -35,7 +35,7 @@ int MatMul::InferShape(std::vector inputs_, std::vector a_shape = input0->shape();
   std::vector<int> b_shape = input1->shape();
-  if (a_shape.size() < 3 || b_shape.size() < 3) {
+  if (a_shape.size() < 2 || b_shape.size() < 2) {
     MS_LOG(ERROR) << "inputs shape is invalid";
     return RET_INPUT_TENSOR_ERROR;
   }
diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc
index 9bc41f02e1..ba17b9524a 100644
--- a/mindspore/lite/src/ops/power.cc
+++ b/mindspore/lite/src/ops/power.cc
@@ -24,24 +24,20 @@ int Power::InferShape(std::vector inputs, std::vectorprimitive != nullptr);
   auto x_tensor = inputs[0];
   MS_ASSERT(x_tensor != nullptr);
-  auto exp_tensor = inputs[1];
-  MS_ASSERT(exp_tensor != nullptr);
+  tensor::Tensor *exp_tensor = nullptr;
+  if (inputs.size() == 2) {
+    exp_tensor = inputs[1];
+    MS_ASSERT(exp_tensor != nullptr);
+  }
   auto output_tensor = outputs[0];
   MS_ASSERT(output_tensor != nullptr);
-  if (inputs.size() < 2) {
-    MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
-    MS_LOG(ERROR) << "Power inputs shape is not equal!";
-    return RET_INPUT_TENSOR_ERROR;
+  if (exp_tensor) {
+    if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
+      MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
+      return RET_INPUT_TENSOR_ERROR;
+    }
   }
-  int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies<int>());
-  if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
-    MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
-    return RET_INPUT_TENSOR_ERROR;
-  }
   output_tensor->SetFormat(x_tensor->GetFormat());
   output_tensor->set_shape(x_tensor->shape());
   output_tensor->set_data_type(x_tensor->data_type());
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
index eb88cfb4b3..56024abf33 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
@@ -69,4 +69,5 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vectorshape();
-  auto o_shape = outputs_[0]->shape();
-  for (int i = 0; i < x_shape.size() - 2; ++i) {
-    batch *= x_shape[i];
+  auto a_shape = inputs_[0]->shape();
+  auto c_shape = outputs_[0]->shape();
+  for (int i = 0; i < a_shape.size() - 2; ++i) {
+    batch *= a_shape[i];
   }
   params_->batch = batch;
-  params_->row_ = o_shape[o_shape.size() - 2];
-  params_->col_ = o_shape[o_shape.size() - 1];
-  params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
+  params_->row_ = c_shape[c_shape.size() - 2];
+  params_->col_ = c_shape[c_shape.size() - 1];
+  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
   thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
index 26e9a0c5ea..df4e5974d6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
@@ -51,15 +51,19 @@ int PowerCPUKernel::Run() {
 int PowerCPUKernel::RunImpl(int task_id) {
   auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
-  auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
   auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
   auto size = inputs_[0]->ElementsNum();
   int stride = UP_DIV(size, thread_count_);
   int len = MSMIN(stride, size - stride * task_id);
-  bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
+  float *exp_addr = nullptr;
+  bool broadcast = true;
+  if (inputs_.size() == 2) {
+    exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
+    broadcast = false;
+  }
   float *cur_exp;
   if (broadcast) {
-    cur_exp = exp_addr;
+    cur_exp = &power_;
   } else {
     cur_exp = exp_addr + stride * task_id;
   }
@@ -73,8 +77,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vectorthread_num_),
+      power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
       scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
       shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
   ~PowerCPUKernel() override = default;
@@ -42,6 +43,7 @@ class PowerCPUKernel : public LiteKernel {
  private:
   const lite::Context *ctx_;
   int thread_count_;
+  float power_;
   float scale_;
   float shift_;
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s
index b33c71d34e..2231e8debf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s
@@ -1,9 +1,9 @@
 #ifdef __aarch64__
     .text
     .align 5
-    .global MatMulFloatNeon64
+    .global MatmulFloatNeon64
 #ifndef __APPLE__
-    .type MatMulFloatNeon64, %function
+    .type MatmulFloatNeon64, %function
 #endif

 // A: LM [row_8 * depth] col_8_major
@@ -46,41 +46,39 @@
 // accumulators 8x8 block
 /////////////////////////////////////////////////////////////////////////////////
 //
-// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col)
+// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
 // x0: a
 // x1: b
 // x2: c
 // x3: bias
-// v0.s[0]: maxf
-// v1.s[0]: minf
-// w4: depth
-// w5: row
-// w6: col
+// w4: act_type
+// w5: depth
+// w6: row
+// w7: col

-MatMulFloatNeon64:
+MatmulFloatNeon64:
     sub sp, sp, #128
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
-    mov w7, v0.s[0]
-    mov w8, v1.s[0]
-    mov w9, 0     // rm col offset
-    mov w10, 0    // lm row offset
+    mov w9, #0    // rm col offset
+    mov w10, #0   // lm row offset
     mov w18, #32  // sizeof(float)*8
-    mul w15, w4, w18  // the stride of lm/rm: sizeof(float)*8*depth
-
+    mul w15, w5, w18  // the stride of lm/rm: sizeof(float)*8*depth
+    mov x11, x3   // bias flag

 L1:
-    cmp w9, w6
+    cmp w9, w7
     beq End1
-    mov w10, 0    // reset lm row offset
+    mov w10, #0   // reset lm row offset
     mov x12, x0   // reload lm ptr
-    mov x14, x3   // reload bias ptr

 L2:
     cmp w10, w6
     beq End2
-    mov w13, w4   // reload depth
+    mov x16, x1   // reload rm ptr
+    mov w13, w5   // reload depth
+    mov x14, x3   // reload bias ptr
     dup v16.4s, wzr
     dup v17.4s, wzr
     dup v18.4s, wzr
@@ -103,7 +101,7 @@ OptLoopMul4:
     blt CommLoopMul

     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v8.4s, v9.4s}, [x1], #32
+    ld1 {v8.4s, v9.4s}, [x16], #32
     fmla v16.4s, v8.4s, v0.s[0]
     fmla v17.4s, v9.4s, v0.s[0]
     fmla v18.4s, v8.4s, v0.s[1]
@@ -112,7 +110,7 @@ OptLoopMul4:
     fmla v21.4s, v9.4s, v0.s[2]
     fmla v22.4s, v8.4s, v0.s[3]
     fmla v23.4s, v9.4s, v0.s[3]
-    ld1 {v10.4s, v11.4s}, [x1], #32
+    ld1 {v10.4s, v11.4s}, [x16], #32
     fmla v24.4s, v8.4s, v1.s[0]
     fmla v25.4s, v9.4s, v1.s[0]
     fmla v26.4s, v8.4s, v1.s[1]
@@ -130,7 +128,7 @@ OptLoopMul4:
     fmla v21.4s, v11.4s, v2.s[2]
     fmla v22.4s, v10.4s, v2.s[3]
     fmla v23.4s, v11.4s, v2.s[3]
-    ld1 {v12.4s, v13.4s}, [x1], #32
+    ld1 {v12.4s, v13.4s}, [x16], #32
     fmla v24.4s, v10.4s, v3.s[0]
     fmla v25.4s, v11.4s, v3.s[0]
     fmla v26.4s, v10.4s, v3.s[1]
@@ -153,7 +151,7 @@ OptLoopMul4:
     fmla v25.4s, v13.4s, v5.s[0]
     fmla v26.4s, v12.4s, v5.s[1]
     fmla v27.4s, v13.4s, v5.s[1]
-    ld1 {v14.4s, v15.4s}, [x1], #32
+    ld1 {v14.4s, v15.4s}, [x16], #32
    fmla v28.4s, v12.4s, v5.s[2]
     fmla v29.4s, v13.4s, v5.s[2]
     fmla v30.4s, v12.4s, v5.s[3]
@@ -182,7 +180,7 @@ CommLoopMul:
     blt Bias

     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v2.4s, v3.4s}, [x1], #32
+    ld1 {v2.4s, v3.4s}, [x16], #32
     fmla v16.4s, v2.4s, v0.s[0]
     fmla v17.4s, v3.4s, v0.s[0]
     fmla v18.4s, v2.4s, v0.s[1]
@@ -203,8 +201,7 @@ CommLoopMul:
     b CommLoopMul

 Bias:
-    cmp x3, #0
-    beq Relu
+    cbz x11, Activation
     ld1 {v0.4s}, [x14], #16
     ld1 {v1.4s}, [x14], #16
     fadd v16.4s, v16.4s, v0.4s
@@ -224,9 +221,34 @@ Bias:
     fadd v30.4s, v30.4s, v0.4s
     fadd v31.4s, v31.4s, v1.4s

+Activation:
+    cmp w4, #2
+    beq Relu6
+    cmp w4, #1
+    beq Relu
+    b TransToOut
+Relu6:
+    mov w8, #6
+    dup v15.4s, w8
+    scvtf v15.4s, v15.4s
+    fmin v16.4s, v16.4s, v15.4s
+    fmin v17.4s, v17.4s, v15.4s
+    fmin v18.4s, v18.4s, v15.4s
+    fmin v19.4s, v19.4s, v15.4s
+    fmin v20.4s, v20.4s, v15.4s
+    fmin v21.4s, v21.4s, v15.4s
+    fmin v22.4s, v22.4s, v15.4s
+    fmin v23.4s, v23.4s, v15.4s
+    fmin v24.4s, v24.4s, v15.4s
+    fmin v25.4s, v25.4s, v15.4s
+    fmin v26.4s, v26.4s, v15.4s
+    fmin v27.4s, v27.4s, v15.4s
+    fmin v28.4s, v28.4s, v15.4s
+    fmin v29.4s, v29.4s, v15.4s
+    fmin v30.4s, v30.4s, v15.4s
+    fmin v31.4s, v31.4s, v15.4s
 Relu:
-    dup v15.4s, w7
-    dup v14.4s, w8
+    dup v14.4s, wzr
     fmax v16.4s, v16.4s, v14.4s
     fmax v17.4s, v17.4s, v14.4s
     fmax v18.4s, v18.4s, v14.4s
@@ -244,24 +266,6 @@ Relu:
     fmax v30.4s, v30.4s, v14.4s
     fmax v31.4s, v31.4s, v14.4s

-    fmin v16.4s, v16.4s, v15.4s
-    fmin v17.4s, v17.4s, v15.4s
-    fmin v18.4s, v18.4s, v15.4s
-    fmin v19.4s, v19.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v21.4s, v21.4s, v15.4s
-    fmin v22.4s, v22.4s, v15.4s
-    fmin v23.4s, v23.4s, v15.4s
-    fmin v24.4s, v24.4s, v15.4s
-    fmin v25.4s, v25.4s, v15.4s
-    fmin v26.4s, v26.4s, v15.4s
-    fmin v27.4s, v27.4s, v15.4s
-    fmin v28.4s, v28.4s, v15.4s
-    fmin v29.4s, v29.4s, v15.4s
-    fmin v30.4s, v30.4s, v15.4s
-    fmin v31.4s, v31.4s, v15.4s
-
 TransToOut:
     st1 {v16.4s}, [x2], #16
     st1 {v17.4s}, [x2], #16
@@ -280,11 +284,13 @@ TransToOut:
     st1 {v30.4s}, [x2], #16
     st1 {v31.4s}, [x2], #16

-    add w10, w10, #8  // lhs row offset + 8
+    add w10, w10, #8  // lm row offset + 8
     b L2

 End2:
-    add w9, w9, #8    // rhs col offset + 8
+    add w9, w9, #8    // rm col offset + 8
+    add x1, x1, x15   // rm ptr + stride
+    add x3, x3, x18   // bias ptr + stride
     b L1

 End1:
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc
index 6c841f7f39..3155fe9561 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc
@@ -42,7 +42,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
       float *dst_c = dst_r + ci * C8NUM;

       /* 8x4 row-major to col-major */
-#ifdef ENABLE_NEON
+#ifdef ENABLE_ARM64
       size_t stride = col * 4;
       asm volatile(
         "mov x10, %[src_c]\n"
@@ -156,6 +156,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActT

 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, int col_8_) {
+#ifdef __aarch64__
+  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
+#else
   MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
-  return;
+#endif
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
index e92f6004b7..ad2e4531bb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
@@ -32,8 +32,8 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
 extern "C" {
 #endif
 #ifdef __aarch64__
-void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
-                       int row, int col);
+void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col);
 #endif
 #ifdef __cplusplus
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h
index 421d582c5e..65915b5ca4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h
@@ -157,10 +157,10 @@ inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32
 // quantize from float to int8
 inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
-    int r = (int)round(input_data[i] / scale + zero_point);
-    int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r;
+    int q = (int)round(input_data[i] / scale + zero_point);
+    q = q > CHAR_MAX ? CHAR_MAX : q;
     q = q < CHAR_MIN ? CHAR_MIN : q;
-    output_data[i] = q;
+    output_data[i] = (int8_t)q;
   }
 }
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
index 8ce0da32d0..2d0fe33be1 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
@@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
     0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552, -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459, 0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
-  std::vector<int> a_shape = {1, 2, 8};
-  std::vector<int> b_shape = {1, 8, 3};
-  std::vector<int> c_shape = {1, 2, 3};
+  std::vector<int> a_shape = {2, 8};
+  std::vector<int> b_shape = {8, 3};
+  std::vector<int> c_shape = {2, 3};
   int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
   auto ctx = new lite::Context;
-  ctx->thread_num_ = 2;
+  ctx->thread_num_ = 1;
   auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx);
   mm->Init();
   mm->Run();
   float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779, -0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
+  delete mm;
+  for (auto t : inputs_) delete t;
+  for (auto t : outputs_) delete t;
+}
+
+TEST_F(TestMatMulFp32, simple2) {
+  std::vector inputs_;
+  std::vector outputs_;
+  auto matmul_param = new MatMulParameter();
+  matmul_param->a_transpose_ = false;
+  matmul_param->b_transpose_ = false;
+  matmul_param->has_bias_ = false;
+  float a[25 * 12] = {
+    1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5, 4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3, 0, 4, 2, 3, 2, 9,
+    8, 9, 5, 4, 4, 9, 7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10, 1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
+    3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6, 1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3, 10, 4, 7, 7, 0, 5,
+    6, 5, 4, 6, 5, 5, 3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7, 6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
+    7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6, 9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6, 4, 10, 9, 8, 10, 5,
+    2, 3, 8, 10, 0, 8, 2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0, 1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
+    2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10, 4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5, 5, 1, 5, 5, 5, 1,
+    8, 2, 6, 9, 10, 4, 9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7, 7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
+    2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5, 4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2, 10, 5, 9, 8, 1, 4,
+    7, 9, 6, 5, 0, 3, 6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4, 5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
+  float b[12 * 36] = {
+    6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10, 9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3,
+    10, 4, 2, 7, 4, 3, 10, 5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8, 5, 2, 1, 7,
+    5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5, 0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10,
+    7, 6, 4, 10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10, 8, 3, 3, 4, 6, 2, 6, 0,
+    4, 0, 3, 4, 1, 0, 3, 10, 5, 4, 0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0, 1,
+    10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3, 3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5,
+    7, 3, 3, 1, 1, 10, 5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2, 7, 10, 9, 7, 10,
+    2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0, 5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9,
+    5, 5, 6, 0, 4, 1,
2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8, 1, 0, 10, 8, 8, 1, 8, 0, 10, + 3, 1, 7, 0, 10, 5, 0, 2, 8, 4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4, 8, 10, + 0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1, 1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5, + 9, 7, 4, 10, 10, 3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2, 0, 8, 7, 6, 6, 6, + 5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9, 6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2, + 7, 2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7, 6, 9, 1, 10, 4, 10, 5, 10, 0, 9, + 5, 1, 6, 2, 9, 9, 8, 8, 10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2}; + std::vector a_shape = {25, 12}; + std::vector b_shape = {12, 36}; + std::vector c_shape = {25, 36}; + int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + mm->Init(); + mm->Run(); + float correct[] = { + 263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307, 329, 314, 391, 261, 400, + 280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393, 226, 285, 134, 209, 264, 163, 281, 212, 219, 171, + 221, 228, 227, 131, 289, 196, 204, 270, 238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280, + 142, 229, 256, 270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301, 296, 254, + 316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292, 278, 388, 199, 324, 252, 336, 385, + 300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366, 293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262, + 249, 309, 258, 280, 286, 411, 268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295, + 250, 241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273, 199, 253, 246, 315, + 307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296, 300, 217, 274, 196, 285, 264, 351, 339, 312, + 289, 338, 282, 256, 274, 214, 243, 228, 302, 276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177, + 405, 236, 294, 293, 264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374, 306, + 350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359, 335, 286, 333, 263, 327, 275, + 448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397, 259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220, + 207, 233, 173, 326, 316, 223, 301, 237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233, + 189, 237, 246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278, 254, 178, 215, + 164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308, 298, 409, 306, 359, 359, 273, 388, 291, + 301, 281, 239, 395, 323, 290, 505, 398, 370, 381, 365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354, + 365, 284, 309, 338, 469, 285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330, + 257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357, 328, 481, 196, 345, 276, + 338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354, 353, 340, 363, 385, 414, 387, 541, 528, 412, 515, + 486, 298, 320, 438, 254, 361, 454, 494, 120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264, + 169, 137, 219, 165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 
152, 218, 226, 273, + 205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282, 274, 238, 358, 207, 333, 270, 345, + 345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338, 319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351, + 311, 204, 421, 319, 294, 348, 328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381, + 378, 200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216, 225, 155, 223, 203, + 203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349, 286, 389, 180, 262, 306, 275, 269, 284, 257, + 239, 256, 262, 270, 189, 410, 306, 302, 297, 244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355, + 258, 252, 286, 446, 274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328, 342, + 247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416, 384, 533, 202, 336, 369, 322, + 449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363, 355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449, + 390, 345, 444, 378, 307, 454, 542, 356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323, + 311, 348, 320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389, 211, 314, 186, + 248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249, 283, 241, 331, 210, 318, 191, 341, 330, + 331, 323, 278, 289, 255, 259, 294, 174, 280, 323, 295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302, + 265, 457, 396, 297, 366, 350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428, + 383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379, 296, 262, 406, 270, 346, + 346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493, 286, 348, 185, 240, 287, 214, 312, 265, 237, 218, + 261, 316, 279, 186, 377, 319, 279, 304, 281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239, + 243, 319, 346}; + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); delete mm; for (auto t : inputs_) delete t; for (auto t : outputs_) delete t; @@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) { mm->Run(); float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031}; CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); - delete matmul_param; delete mm; for (auto t : inputs_) delete t; for (auto t : outputs_) delete t; @@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) { 8.869029998779297, 25.034008026123047}; float *output = reinterpret_cast(outputs_[0]->Data()); - for (int i = 0; i < 18; ++i) printf("%f ", output[i]); CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); - delete matmul_param; delete mm; for (auto t : inputs_) delete t; for (auto t : outputs_) delete t;
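For reference, a minimal standalone C++ sketch (illustrative only, not the library code) of the activation step that the new `Activation`/`Relu6`/`Relu` blocks in `matmul.s` apply to each accumulator before the store, i.e. how the old `maxf`/`minf` clamp maps onto the integer `act_type` argument. The mapping 1 = ReLU and 2 = ReLU6 is taken from the `cmp w4, #1` / `cmp w4, #2` branches in the diff; everything else here is assumed for illustration.

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the per-accumulator activation in the rewritten matmul.s:
// act_type == 2 clamps to [0, 6] (ReLU6), act_type == 1 clamps to
// [0, +inf) (ReLU), any other value stores the accumulated sum unchanged.
static float ApplyActivation(float value, int act_type) {
  if (act_type == 2) {
    value = std::min(value, 6.0f);  // Relu6 upper bound, then fall through to the shared ReLU max
  }
  if (act_type == 2 || act_type == 1) {
    value = std::max(value, 0.0f);  // shared ReLU lower bound
  }
  return value;
}

int main() {
  const float samples[] = {-1.5f, 3.0f, 8.0f};
  for (float v : samples) {
    std::printf("%5.1f -> none %5.1f, relu %5.1f, relu6 %5.1f\n", v,
                ApplyActivation(v, 0), ApplyActivation(v, 1), ApplyActivation(v, 2));
  }
  return 0;
}
```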