Fix matmul asm bugs

pull/4133/head
zhanyuan 5 years ago
parent 7ec1cb4e3d
commit b99d8590a1

@@ -35,7 +35,7 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape();
if (a_shape.size() < 3 || b_shape.size() < 3) {
if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid";
return RET_INPUT_TENSOR_ERROR;
}
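Note: relaxing the rank check from 3 to 2 lets plain 2-D inputs pass InferShape, so a {2, 8} x {8, 3} = {2, 3} matmul (the shapes used by the updated simple test below) no longer needs a dummy leading batch dimension; any leading dimensions that are present are still folded into the batch count in MatmulCPUKernel::Init.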

@@ -24,24 +24,20 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
MS_ASSERT(this->primitive != nullptr);
auto x_tensor = inputs[0];
MS_ASSERT(x_tensor != nullptr);
auto exp_tensor = inputs[1];
MS_ASSERT(exp_tensor != nullptr);
tensor::Tensor *exp_tensor = nullptr;
if (inputs.size() == 2) {
exp_tensor = inputs[1];
MS_ASSERT(exp_tensor != nullptr);
}
auto output_tensor = outputs[0];
MS_ASSERT(output_tensor != nullptr);
if (inputs.size() < 2) {
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
return RET_INPUT_TENSOR_ERROR;
}
if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
MS_LOG(ERROR) << "Power inputs shape is not equal!";
return RET_INPUT_TENSOR_ERROR;
if (exp_tensor) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
}
int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies<int>());
if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
return RET_INPUT_TENSOR_ERROR;
}
output_tensor->SetFormat(x_tensor->GetFormat());
output_tensor->set_shape(x_tensor->shape());
output_tensor->set_data_type(x_tensor->data_type());
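Note: the exponent input becomes optional here. With two inputs the exponent is applied element-wise and must match the base tensor's shape and data type; with one input the scalar power_ attribute from PowerParameter is used instead (see the kernel changes below). A minimal reference sketch of the assumed semantics (scale_/shift_/power_ as in PowerParameter; exp may be null when the second input is absent):

  #include <cmath>
  // Sketch only, not the kernel implementation.
  void PowerRef(const float *x, const float *exp, float *y, int n,
                float power, float scale, float shift) {
    for (int i = 0; i < n; ++i) {
      float e = (exp != nullptr) ? exp[i] : power;  // element-wise vs. scalar exponent
      y[i] = std::pow(scale * x[i] + shift, e);
    }
  }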

@@ -69,4 +69,5 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulKernelCreator)
} // namespace mindspore::kernel

@@ -34,15 +34,15 @@ int MatmulCPUKernel::ReSize() { return RET_OK; }
int MatmulCPUKernel::Init() {
int batch = 1;
auto x_shape = inputs_[0]->shape();
auto o_shape = outputs_[0]->shape();
for (int i = 0; i < x_shape.size() - 2; ++i) {
batch *= x_shape[i];
auto a_shape = inputs_[0]->shape();
auto c_shape = outputs_[0]->shape();
for (int i = 0; i < a_shape.size() - 2; ++i) {
batch *= a_shape[i];
}
params_->batch = batch;
params_->row_ = o_shape[o_shape.size() - 2];
params_->col_ = o_shape[o_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
params_->row_ = c_shape[c_shape.size() - 2];
params_->col_ = c_shape[c_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
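For reference, working the derivation above through the shapes of the updated simple test below (a_transpose_ = false):

  // A = {2, 8}, B = {8, 3}, C = {2, 3}
  // batch  = 1                    (no leading batch dimensions)
  // row_   = 2,  col_ = 3,  deep_ = 8
  // row_8_ = UP_ROUND(2, 8) = 8,  col_8_ = UP_ROUND(3, 8) = 8
  // thread_count_ = MSMIN(thread_count_, UP_DIV(8, 8)) -> at most one column block,
  // so a single thread, matching ctx->thread_num_ = 1 in that test.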

@@ -51,15 +51,19 @@ int PowerCPUKernel::Run() {
int PowerCPUKernel::RunImpl(int task_id) {
auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
auto size = inputs_[0]->ElementsNum();
int stride = UP_DIV(size, thread_count_);
int len = MSMIN(stride, size - stride * task_id);
bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
float *exp_addr = nullptr;
bool broadcast = true;
if (inputs_.size() == 2) {
exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
broadcast = false;
}
float *cur_exp;
if (broadcast) {
cur_exp = exp_addr;
cur_exp = &power_;
} else {
cur_exp = exp_addr + stride * task_id;
}
@@ -73,8 +77,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
auto *kernel =
new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
auto *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerCPUKernel fail!";
return nullptr;

@@ -30,6 +30,7 @@ class PowerCPUKernel : public LiteKernel {
: LiteKernel(param, inputs, outputs),
ctx_(ctx),
thread_count_(ctx->thread_num_),
power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
~PowerCPUKernel() override = default;
@@ -42,6 +43,7 @@ class PowerCPUKernel : public LiteKernel {
private:
const lite::Context *ctx_;
int thread_count_;
float power_;
float scale_;
float shift_;
};

@@ -1,9 +1,9 @@
#ifdef __aarch64__
.text
.align 5
.global MatMulFloatNeon64
.global MatmulFloatNeon64
#ifndef __APPLE__
.type MatMulFloatNeon64, %function
.type MatmulFloatNeon64, %function
#endif
// A: LM [row_8 * depth] col_8_major
@@ -46,41 +46,39 @@
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
//
// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col)
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
// x0: a
// x1: b
// x2: c
// x3: bias
// v0.s[0]: maxf
// v1.s[0]: minf
// w4: depth
// w5: row
// w6: col
// w4: act_type
// w5: depth
// w6: row
// w7: col
MatMulFloatNeon64:
MatmulFloatNeon64:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov w7, v0.s[0]
mov w8, v1.s[0]
mov w9, 0 // rm col offset
mov w10, 0 // lm row offset
mov w9, #0 // rm col offset
mov w10, #0 // lm row offset
mov w18, #32 // sizeof(float)*8
mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth
mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
mov x11, x3 // bias flag
L1:
cmp w9, w6
cmp w9, w7
beq End1
mov w10, 0 // reset lm row offset
mov w10, #0 // reset lm row offset
mov x12, x0 // reload lm ptr
mov x14, x3 // reload bias ptr
L2:
cmp w10, w6
beq End2
mov w13, w4 // reload depth
mov x16, x1 // reload rm ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
@@ -103,7 +101,7 @@ OptLoopMul4:
blt CommLoopMul
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v8.4s, v9.4s}, [x1], #32
ld1 {v8.4s, v9.4s}, [x16], #32
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v9.4s, v0.s[0]
fmla v18.4s, v8.4s, v0.s[1]
@@ -112,7 +110,7 @@ OptLoopMul4:
fmla v21.4s, v9.4s, v0.s[2]
fmla v22.4s, v8.4s, v0.s[3]
fmla v23.4s, v9.4s, v0.s[3]
ld1 {v10.4s, v11.4s}, [x1], #32
ld1 {v10.4s, v11.4s}, [x16], #32
fmla v24.4s, v8.4s, v1.s[0]
fmla v25.4s, v9.4s, v1.s[0]
fmla v26.4s, v8.4s, v1.s[1]
@@ -130,7 +128,7 @@ OptLoopMul4:
fmla v21.4s, v11.4s, v2.s[2]
fmla v22.4s, v10.4s, v2.s[3]
fmla v23.4s, v11.4s, v2.s[3]
ld1 {v12.4s, v13.4s}, [x1], #32
ld1 {v12.4s, v13.4s}, [x16], #32
fmla v24.4s, v10.4s, v3.s[0]
fmla v25.4s, v11.4s, v3.s[0]
fmla v26.4s, v10.4s, v3.s[1]
@@ -153,7 +151,7 @@ OptLoopMul4:
fmla v25.4s, v13.4s, v5.s[0]
fmla v26.4s, v12.4s, v5.s[1]
fmla v27.4s, v13.4s, v5.s[1]
ld1 {v14.4s, v15.4s}, [x1], #32
ld1 {v14.4s, v15.4s}, [x16], #32
fmla v28.4s, v12.4s, v5.s[2]
fmla v29.4s, v13.4s, v5.s[2]
fmla v30.4s, v12.4s, v5.s[3]
@@ -182,7 +180,7 @@ CommLoopMul:
blt Bias
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v2.4s, v3.4s}, [x1], #32
ld1 {v2.4s, v3.4s}, [x16], #32
fmla v16.4s, v2.4s, v0.s[0]
fmla v17.4s, v3.4s, v0.s[0]
fmla v18.4s, v2.4s, v0.s[1]
@@ -203,8 +201,7 @@ CommLoopMul:
b CommLoopMul
Bias:
cmp x3, #0
beq Relu
cbz x11, Activation
ld1 {v0.4s}, [x14], #16
ld1 {v1.4s}, [x14], #16
fadd v16.4s, v16.4s, v0.4s
@@ -224,9 +221,34 @@ Bias:
fadd v30.4s, v30.4s, v0.4s
fadd v31.4s, v31.4s, v1.4s
Activation:
cmp w4, #2
beq Relu6
cmp w4, #1
beq Relu
b TransToOut
Relu6:
mov w8, #6
dup v15.4s, w8
scvtf v15.4s, v15.4s
fmin v16.4s, v16.4s, v15.4s
fmin v17.4s, v17.4s, v15.4s
fmin v18.4s, v18.4s, v15.4s
fmin v19.4s, v19.4s, v15.4s
fmin v20.4s, v20.4s, v15.4s
fmin v21.4s, v21.4s, v15.4s
fmin v22.4s, v22.4s, v15.4s
fmin v23.4s, v23.4s, v15.4s
fmin v24.4s, v24.4s, v15.4s
fmin v25.4s, v25.4s, v15.4s
fmin v26.4s, v26.4s, v15.4s
fmin v27.4s, v27.4s, v15.4s
fmin v28.4s, v28.4s, v15.4s
fmin v29.4s, v29.4s, v15.4s
fmin v30.4s, v30.4s, v15.4s
fmin v31.4s, v31.4s, v15.4s
Relu:
dup v15.4s, w7
dup v14.4s, w8
dup v14.4s, wzr
fmax v16.4s, v16.4s, v14.4s
fmax v17.4s, v17.4s, v14.4s
fmax v18.4s, v18.4s, v14.4s
@@ -244,24 +266,6 @@ Relu:
fmax v30.4s, v30.4s, v14.4s
fmax v31.4s, v31.4s, v14.4s
fmin v16.4s, v16.4s, v15.4s
fmin v17.4s, v17.4s, v15.4s
fmin v18.4s, v18.4s, v15.4s
fmin v19.4s, v19.4s, v15.4s
fmin v20.4s, v20.4s, v15.4s
fmin v20.4s, v20.4s, v15.4s
fmin v21.4s, v21.4s, v15.4s
fmin v22.4s, v22.4s, v15.4s
fmin v23.4s, v23.4s, v15.4s
fmin v24.4s, v24.4s, v15.4s
fmin v25.4s, v25.4s, v15.4s
fmin v26.4s, v26.4s, v15.4s
fmin v27.4s, v27.4s, v15.4s
fmin v28.4s, v28.4s, v15.4s
fmin v29.4s, v29.4s, v15.4s
fmin v30.4s, v30.4s, v15.4s
fmin v31.4s, v31.4s, v15.4s
TransToOut:
st1 {v16.4s}, [x2], #16
st1 {v17.4s}, [x2], #16
@@ -280,11 +284,13 @@ TransToOut:
st1 {v30.4s}, [x2], #16
st1 {v31.4s}, [x2], #16
add w10, w10, #8 // lhs row offset + 8
add w10, w10, #8 // lm row offset + 8
b L2
End2:
add w9, w9, #8 // rhs col offset + 8
add w9, w9, #8 // rm col offset + 8
add x1, x1, x15 // rm ptr + stride
add x3, x3, x18 // bias ptr + stride
b L1
End1:
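In short: the kernel is renamed to MatmulFloatNeon64, takes an integer activation selector in w4 instead of the old maxf/minf float pair, reads the right-hand matrix through x16 so x1 can be advanced by the column-block stride at End2, reloads the bias pointer once per column block (advancing it by 8 floats at End2, with the original pointer kept in x11 as the has-bias flag), and uses the shifted argument registers (depth in w5, row in w6, col in w7) for the loop bounds and stride. The activation dispatch implied by the branches above (the numeric values are read off the compares, not the ActType enum definition):

  // act_type == 2 : Relu6 -> c = min(c, 6), then falls through to the ReLU clamp
  // act_type == 1 : Relu  -> c = max(c, 0)
  // otherwise     : no activation, results go straight to TransToOut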

@@ -42,7 +42,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
float *dst_c = dst_r + ci * C8NUM;
/* 8x4 row-major to col-major */
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
size_t stride = col * 4;
asm volatile(
"mov x10, %[src_c]\n"
@@ -156,6 +156,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActT
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
int col_8_) {
#ifdef __aarch64__
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
#else
MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
return;
#endif
}
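With the new signature the activation is resolved inside the assembly, so the float clamp bounds no longer have to cross the C/asm boundary; non-aarch64 builds keep using the portable MatMul8x8 tile loop. A hypothetical call site (the ActType member name is an assumption, not taken from this diff):

  // Multiply an 8-padded A block [row_8 x deep] by B [deep x col_8] with a fused ReLU.
  MatMul(a, b, c, bias, ActType_Relu, deep, row_8, col_8);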

@@ -32,8 +32,8 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
extern "C" {
#endif
#ifdef __aarch64__
void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
int row, int col);
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col);
#endif
#ifdef __cplusplus
}

@@ -157,10 +157,10 @@ inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32
// quantize from float to int8
inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
for (int i = 0; i < length; ++i) {
int r = (int)round(input_data[i] / scale + zero_point);
int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r;
int q = (int)round(input_data[i] / scale + zero_point);
q = q > CHAR_MAX ? CHAR_MAX : q;
q = q < CHAR_MIN ? CHAR_MIN : q;
output_data[i] = q;
output_data[i] = (int8_t)q;
}
}
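Worked example of the narrowing bug this reorder fixes (assuming scale = 1.0f and zero_point = 0):

  // input = -200.0f
  //   old: (int8_t)(-200) wraps to 56 before the CHAR_MIN check, so 56 is stored
  //   new: -200 is clamped to CHAR_MIN (-128) while still an int, then cast, so -128 is stored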

@@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
-0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
std::vector<int> a_shape = {1, 2, 8};
std::vector<int> b_shape = {1, 8, 3};
std::vector<int> c_shape = {1, 2, 3};
std::vector<int> a_shape = {2, 8};
std::vector<int> b_shape = {8, 3};
std::vector<int> c_shape = {2, 3};
int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
ctx->thread_num_ = 1;
auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
mm->Init();
mm->Run();
float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779,
-0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete matmul_param;
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
}
TEST_F(TestMatMulFp32, simple2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
matmul_param->a_transpose_ = false;
matmul_param->b_transpose_ = false;
matmul_param->has_bias_ = false;
float a[25 * 12] = {
1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5, 4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3, 0, 4, 2, 3, 2, 9,
8, 9, 5, 4, 4, 9, 7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10, 1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6, 1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3, 10, 4, 7, 7, 0, 5,
6, 5, 4, 6, 5, 5, 3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7, 6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6, 9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6, 4, 10, 9, 8, 10, 5,
2, 3, 8, 10, 0, 8, 2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0, 1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10, 4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5, 5, 1, 5, 5, 5, 1,
8, 2, 6, 9, 10, 4, 9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7, 7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5, 4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2, 10, 5, 9, 8, 1, 4,
7, 9, 6, 5, 0, 3, 6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4, 5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
float b[12 * 36] = {
6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10, 9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3,
10, 4, 2, 7, 4, 3, 10, 5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8, 5, 2, 1, 7,
5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5, 0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10,
7, 6, 4, 10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10, 8, 3, 3, 4, 6, 2, 6, 0,
4, 0, 3, 4, 1, 0, 3, 10, 5, 4, 0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0, 1,
10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3, 3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5,
7, 3, 3, 1, 1, 10, 5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2, 7, 10, 9, 7, 10,
2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0, 5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9,
5, 5, 6, 0, 4, 1, 2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8, 1, 0, 10, 8, 8, 1, 8, 0, 10,
3, 1, 7, 0, 10, 5, 0, 2, 8, 4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4, 8, 10,
0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1, 1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5,
9, 7, 4, 10, 10, 3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2, 0, 8, 7, 6, 6, 6,
5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9, 6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2,
7, 2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7, 6, 9, 1, 10, 4, 10, 5, 10, 0, 9,
5, 1, 6, 2, 9, 9, 8, 8, 10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2};
std::vector<int> a_shape = {25, 12};
std::vector<int> b_shape = {12, 36};
std::vector<int> c_shape = {25, 36};
int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
mm->Init();
mm->Run();
float correct[] = {
263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307, 329, 314, 391, 261, 400,
280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393, 226, 285, 134, 209, 264, 163, 281, 212, 219, 171,
221, 228, 227, 131, 289, 196, 204, 270, 238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280,
142, 229, 256, 270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301, 296, 254,
316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292, 278, 388, 199, 324, 252, 336, 385,
300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366, 293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262,
249, 309, 258, 280, 286, 411, 268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295,
250, 241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273, 199, 253, 246, 315,
307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296, 300, 217, 274, 196, 285, 264, 351, 339, 312,
289, 338, 282, 256, 274, 214, 243, 228, 302, 276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177,
405, 236, 294, 293, 264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374, 306,
350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359, 335, 286, 333, 263, 327, 275,
448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397, 259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220,
207, 233, 173, 326, 316, 223, 301, 237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233,
189, 237, 246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278, 254, 178, 215,
164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308, 298, 409, 306, 359, 359, 273, 388, 291,
301, 281, 239, 395, 323, 290, 505, 398, 370, 381, 365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354,
365, 284, 309, 338, 469, 285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330,
257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357, 328, 481, 196, 345, 276,
338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354, 353, 340, 363, 385, 414, 387, 541, 528, 412, 515,
486, 298, 320, 438, 254, 361, 454, 494, 120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264,
169, 137, 219, 165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 152, 218, 226, 273,
205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282, 274, 238, 358, 207, 333, 270, 345,
345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338, 319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351,
311, 204, 421, 319, 294, 348, 328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381,
378, 200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216, 225, 155, 223, 203,
203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349, 286, 389, 180, 262, 306, 275, 269, 284, 257,
239, 256, 262, 270, 189, 410, 306, 302, 297, 244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355,
258, 252, 286, 446, 274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328, 342,
247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416, 384, 533, 202, 336, 369, 322,
449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363, 355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449,
390, 345, 444, 378, 307, 454, 542, 356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323,
311, 348, 320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389, 211, 314, 186,
248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249, 283, 241, 331, 210, 318, 191, 341, 330,
331, 323, 278, 289, 255, 259, 294, 174, 280, 323, 295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302,
265, 457, 396, 297, 366, 350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428,
383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379, 296, 262, 406, 270, 346,
346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493, 286, 348, 185, 240, 287, 214, 312, 265, 237, 218,
261, 316, 279, 186, 377, 319, 279, 304, 281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239,
243, 319, 346};
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) {
mm->Run();
float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031};
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete matmul_param;
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) {
8.869029998779297, 25.034008026123047};
float *output = reinterpret_cast<float *>(outputs_[0]->Data());
for (int i = 0; i < 18; ++i) printf("%f ", output[i]);
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete matmul_param;
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
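The new simple2 case multiplies a 25x12 matrix by a 12x36 matrix, so neither row (25) nor col (36) is a multiple of the 8x8 tile (row_8_ = 32, col_8_ = 40) and ctx->thread_num_ = 2 spreads the five padded column blocks over two threads; together with the simple test dropping its dummy batch dimension, this exercises the re-based rm/bias pointers and the act_type path in the rewritten assembly.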
