|
|
|
@ -68,117 +68,6 @@ TEST_F(TestMatMulFp32, Row2Col8Test2) {
|
|
|
|
|
CompareOutputData(out, co, 120, 0.0001);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest1) {
|
|
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0,
|
|
|
|
|
0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0,
|
|
|
|
|
0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0,
|
|
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0, 0, 0, 0.61, 0.43, 0.14, 0.67, 0.10, 0, 0, 0,
|
|
|
|
|
0.73, 0.37, 0.24, 0.93, 0.31, 0, 0, 0, 0.35, 0.52, 0.02, 0.33, 0.99, 0, 0, 0,
|
|
|
|
|
0.49, 0.67, 0.75, 0.66, 0.04, 0, 0, 0, 0.10, 0.18, 0.92, 0.46, 0.08, 0, 0, 0,
|
|
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0, 0, 0, 0.67, 0.10, 0.73, 0.37, 0.24, 0, 0, 0,
|
|
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0, 0, 0, 0.33, 0.99, 0.49, 0.67, 0.75, 0, 0, 0,
|
|
|
|
|
0.66, 0.04, 0.10, 0.18, 0.92, 0, 0, 0, 0.46, 0.08, 0.04, 0.24, 0.52, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
|
|
|
|
float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90,
|
|
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39,
|
|
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31,
|
|
|
|
|
0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08,
|
|
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02,
|
|
|
|
|
0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52};
|
|
|
|
|
float out[90] = {0};
|
|
|
|
|
Row8x8Major2RowMajor(in, out, 18, 5, 5);
|
|
|
|
|
CompareOutputData(out, co, 90, 0.0001);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest2) {
|
|
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0,
|
|
|
|
|
0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0,
|
|
|
|
|
0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
|
|
|
|
float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90,
|
|
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39};
|
|
|
|
|
float out[30] = {0};
|
|
|
|
|
Row8x8Major2RowMajor(in, out, 6, 5, 5);
|
|
|
|
|
CompareOutputData(out, co, 30, 0.0001);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest3) {
|
|
|
|
|
float in[] = {
|
|
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73,
|
|
|
|
|
0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04,
|
|
|
|
|
0.10, 0.18, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.93,
|
|
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.75, 0.66, 0.04, 0.10,
|
|
|
|
|
0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02,
|
|
|
|
|
0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46,
|
|
|
|
|
0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50,
|
|
|
|
|
0.39, 0.09, 0.93, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
|
|
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81,
|
|
|
|
|
0.98, 0.09, 0.68, 0.02, 0.33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
0, 0, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52,
|
|
|
|
|
0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04,
|
|
|
|
|
0.24, 0.52, 0.21, 0.38, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67,
|
|
|
|
|
0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.04, 0.24,
|
|
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, 0.67, 0.81, 0.57, 0.70,
|
|
|
|
|
0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66,
|
|
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97,
|
|
|
|
|
0.61, 0.43, 0.14, 0.67, 0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
|
|
|
|
|
0.99, 0.49, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85,
|
|
|
|
|
0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0,
|
|
|
|
|
0, 0.04, 0.10, 0.18, 0, 0, 0, 0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.81, 0.98,
|
|
|
|
|
0.09, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, 0.37, 0.24, 0, 0,
|
|
|
|
|
0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, 0, 0, 0, 0,
|
|
|
|
|
0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, 0, 0.13, 0.03, 0.53,
|
|
|
|
|
0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, 0, 0.04, 0.10, 0.18, 0, 0, 0,
|
|
|
|
|
0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73,
|
|
|
|
|
0.37, 0.24, 0, 0, 0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0,
|
|
|
|
|
0, 0, 0, 0, 0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0,
|
|
|
|
|
0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
|
|
|
0, 0, 0, 0, 0, 0};
|
|
|
|
|
float co[] = {
|
|
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53,
|
|
|
|
|
0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14,
|
|
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18,
|
|
|
|
|
0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
|
|
|
|
|
0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.21, 0.38, 0.81, 0.98, 0.09,
|
|
|
|
|
0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78,
|
|
|
|
|
0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
|
|
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24,
|
|
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67,
|
|
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93,
|
|
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52,
|
|
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53,
|
|
|
|
|
0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14,
|
|
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18,
|
|
|
|
|
0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
|
|
|
|
|
0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78,
|
|
|
|
|
0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
|
|
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24,
|
|
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67,
|
|
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93,
|
|
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52,
|
|
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53};
|
|
|
|
|
float out[418] = {0};
|
|
|
|
|
Row8x8Major2RowMajor(in, out, 22, 19, 19);
|
|
|
|
|
CompareOutputData(out, co, 418, 0.0001);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest4) {
|
|
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
|
|
|
|
|
0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
|
|
|
|
|
0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
|
|
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
|
|
|
|
|
float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
|
|
|
|
|
0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
|
|
|
|
|
0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
|
|
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
|
|
|
|
|
float out[64] = {0};
|
|
|
|
|
Row8x8Major2RowMajor(in, out, 8, 8, 8);
|
|
|
|
|
CompareOutputData(out, co, 64, 0.0001);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MMTestInit(std::vector<lite::Tensor *> *inputs_, std::vector<lite::Tensor *> *outputs_, float *a_ptr, float *b_ptr,
|
|
|
|
|
std::vector<int> a_shape, std::vector<int> b_shape, std::vector<int> c_shape) {
|
|
|
|
|
auto in_t = new lite::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, lite::Tensor::Category::CONST);
|
|
|
|
|