|
|
|
@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
|
|
|
|
|
0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
|
|
|
|
|
-0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
|
|
|
|
|
0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
|
|
|
|
|
std::vector<int> a_shape = {1, 2, 8};
|
|
|
|
|
std::vector<int> b_shape = {1, 8, 3};
|
|
|
|
|
std::vector<int> c_shape = {1, 2, 3};
|
|
|
|
|
std::vector<int> a_shape = {2, 8};
|
|
|
|
|
std::vector<int> b_shape = {8, 3};
|
|
|
|
|
std::vector<int> c_shape = {2, 3};
|
|
|
|
|
int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
|
|
|
|
|
auto ctx = new lite::Context;
|
|
|
|
|
ctx->thread_num_ = 2;
|
|
|
|
|
ctx->thread_num_ = 1;
|
|
|
|
|
auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
|
|
|
|
|
mm->Init();
|
|
|
|
|
mm->Run();
|
|
|
|
|
float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779,
|
|
|
|
|
-0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
|
|
|
|
|
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
|
|
|
|
|
delete matmul_param;
|
|
|
|
|
delete mm;
|
|
|
|
|
for (auto t : inputs_) delete t;
|
|
|
|
|
for (auto t : outputs_) delete t;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, simple2) {
|
|
|
|
|
std::vector<lite::tensor::Tensor *> inputs_;
|
|
|
|
|
std::vector<lite::tensor::Tensor *> outputs_;
|
|
|
|
|
auto matmul_param = new MatMulParameter();
|
|
|
|
|
matmul_param->a_transpose_ = false;
|
|
|
|
|
matmul_param->b_transpose_ = false;
|
|
|
|
|
matmul_param->has_bias_ = false;
|
|
|
|
|
float a[25 * 12] = {
|
|
|
|
|
1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5, 4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3, 0, 4, 2, 3, 2, 9,
|
|
|
|
|
8, 9, 5, 4, 4, 9, 7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10, 1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
|
|
|
|
|
3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6, 1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3, 10, 4, 7, 7, 0, 5,
|
|
|
|
|
6, 5, 4, 6, 5, 5, 3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7, 6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
|
|
|
|
|
7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6, 9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6, 4, 10, 9, 8, 10, 5,
|
|
|
|
|
2, 3, 8, 10, 0, 8, 2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0, 1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
|
|
|
|
|
2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10, 4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5, 5, 1, 5, 5, 5, 1,
|
|
|
|
|
8, 2, 6, 9, 10, 4, 9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7, 7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
|
|
|
|
|
2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5, 4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2, 10, 5, 9, 8, 1, 4,
|
|
|
|
|
7, 9, 6, 5, 0, 3, 6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4, 5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
|
|
|
|
|
float b[12 * 36] = {
|
|
|
|
|
6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10, 9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3,
|
|
|
|
|
10, 4, 2, 7, 4, 3, 10, 5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8, 5, 2, 1, 7,
|
|
|
|
|
5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5, 0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10,
|
|
|
|
|
7, 6, 4, 10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10, 8, 3, 3, 4, 6, 2, 6, 0,
|
|
|
|
|
4, 0, 3, 4, 1, 0, 3, 10, 5, 4, 0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0, 1,
|
|
|
|
|
10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3, 3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5,
|
|
|
|
|
7, 3, 3, 1, 1, 10, 5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2, 7, 10, 9, 7, 10,
|
|
|
|
|
2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0, 5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9,
|
|
|
|
|
5, 5, 6, 0, 4, 1, 2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8, 1, 0, 10, 8, 8, 1, 8, 0, 10,
|
|
|
|
|
3, 1, 7, 0, 10, 5, 0, 2, 8, 4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4, 8, 10,
|
|
|
|
|
0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1, 1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5,
|
|
|
|
|
9, 7, 4, 10, 10, 3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2, 0, 8, 7, 6, 6, 6,
|
|
|
|
|
5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9, 6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2,
|
|
|
|
|
7, 2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7, 6, 9, 1, 10, 4, 10, 5, 10, 0, 9,
|
|
|
|
|
5, 1, 6, 2, 9, 9, 8, 8, 10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2};
|
|
|
|
|
std::vector<int> a_shape = {25, 12};
|
|
|
|
|
std::vector<int> b_shape = {12, 36};
|
|
|
|
|
std::vector<int> c_shape = {25, 36};
|
|
|
|
|
int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
|
|
|
|
|
auto ctx = new lite::Context;
|
|
|
|
|
ctx->thread_num_ = 2;
|
|
|
|
|
auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
|
|
|
|
|
mm->Init();
|
|
|
|
|
mm->Run();
|
|
|
|
|
float correct[] = {
|
|
|
|
|
263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307, 329, 314, 391, 261, 400,
|
|
|
|
|
280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393, 226, 285, 134, 209, 264, 163, 281, 212, 219, 171,
|
|
|
|
|
221, 228, 227, 131, 289, 196, 204, 270, 238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280,
|
|
|
|
|
142, 229, 256, 270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301, 296, 254,
|
|
|
|
|
316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292, 278, 388, 199, 324, 252, 336, 385,
|
|
|
|
|
300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366, 293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262,
|
|
|
|
|
249, 309, 258, 280, 286, 411, 268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295,
|
|
|
|
|
250, 241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273, 199, 253, 246, 315,
|
|
|
|
|
307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296, 300, 217, 274, 196, 285, 264, 351, 339, 312,
|
|
|
|
|
289, 338, 282, 256, 274, 214, 243, 228, 302, 276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177,
|
|
|
|
|
405, 236, 294, 293, 264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374, 306,
|
|
|
|
|
350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359, 335, 286, 333, 263, 327, 275,
|
|
|
|
|
448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397, 259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220,
|
|
|
|
|
207, 233, 173, 326, 316, 223, 301, 237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233,
|
|
|
|
|
189, 237, 246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278, 254, 178, 215,
|
|
|
|
|
164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308, 298, 409, 306, 359, 359, 273, 388, 291,
|
|
|
|
|
301, 281, 239, 395, 323, 290, 505, 398, 370, 381, 365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354,
|
|
|
|
|
365, 284, 309, 338, 469, 285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330,
|
|
|
|
|
257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357, 328, 481, 196, 345, 276,
|
|
|
|
|
338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354, 353, 340, 363, 385, 414, 387, 541, 528, 412, 515,
|
|
|
|
|
486, 298, 320, 438, 254, 361, 454, 494, 120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264,
|
|
|
|
|
169, 137, 219, 165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 152, 218, 226, 273,
|
|
|
|
|
205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282, 274, 238, 358, 207, 333, 270, 345,
|
|
|
|
|
345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338, 319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351,
|
|
|
|
|
311, 204, 421, 319, 294, 348, 328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381,
|
|
|
|
|
378, 200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216, 225, 155, 223, 203,
|
|
|
|
|
203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349, 286, 389, 180, 262, 306, 275, 269, 284, 257,
|
|
|
|
|
239, 256, 262, 270, 189, 410, 306, 302, 297, 244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355,
|
|
|
|
|
258, 252, 286, 446, 274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328, 342,
|
|
|
|
|
247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416, 384, 533, 202, 336, 369, 322,
|
|
|
|
|
449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363, 355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449,
|
|
|
|
|
390, 345, 444, 378, 307, 454, 542, 356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323,
|
|
|
|
|
311, 348, 320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389, 211, 314, 186,
|
|
|
|
|
248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249, 283, 241, 331, 210, 318, 191, 341, 330,
|
|
|
|
|
331, 323, 278, 289, 255, 259, 294, 174, 280, 323, 295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302,
|
|
|
|
|
265, 457, 396, 297, 366, 350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428,
|
|
|
|
|
383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379, 296, 262, 406, 270, 346,
|
|
|
|
|
346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493, 286, 348, 185, 240, 287, 214, 312, 265, 237, 218,
|
|
|
|
|
261, 316, 279, 186, 377, 319, 279, 304, 281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239,
|
|
|
|
|
243, 319, 346};
|
|
|
|
|
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
|
|
|
|
|
delete mm;
|
|
|
|
|
for (auto t : inputs_) delete t;
|
|
|
|
|
for (auto t : outputs_) delete t;
|
|
|
|
@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) {
|
|
|
|
|
mm->Run();
|
|
|
|
|
float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031};
|
|
|
|
|
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
|
|
|
|
|
delete matmul_param;
|
|
|
|
|
delete mm;
|
|
|
|
|
for (auto t : inputs_) delete t;
|
|
|
|
|
for (auto t : outputs_) delete t;
|
|
|
|
@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) {
|
|
|
|
|
8.869029998779297, 25.034008026123047};
|
|
|
|
|
|
|
|
|
|
float *output = reinterpret_cast<float *>(outputs_[0]->Data());
|
|
|
|
|
for (int i = 0; i < 18; ++i) printf("%f ", output[i]);
|
|
|
|
|
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
|
|
|
|
|
delete matmul_param;
|
|
|
|
|
delete mm;
|
|
|
|
|
for (auto t : inputs_) delete t;
|
|
|
|
|
for (auto t : outputs_) delete t;
|
|
|
|
|