|
|
@ -462,7 +462,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// test result from Get function
|
|
|
|
// test result from Get function
|
|
|
|
// VLOG(10) << "Test Get function ";
|
|
|
|
// VLOG(10) << "Test Get function ";
|
|
|
|
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
|
|
|
|
auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
|
|
|
|
test(tgt, args...);
|
|
|
|
test(tgt, args...);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -845,7 +845,9 @@ void TestKernelNCHW16CMulNCTuples() {
|
|
|
|
T* zjit_data = zjit.data();
|
|
|
|
T* zjit_data = zjit.data();
|
|
|
|
constexpr int simd_width = ZMM_FLOAT_BLOCK;
|
|
|
|
constexpr int simd_width = ZMM_FLOAT_BLOCK;
|
|
|
|
int C = c / simd_width;
|
|
|
|
int C = c / simd_width;
|
|
|
|
auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
|
|
|
|
auto tgt =
|
|
|
|
|
|
|
|
jit::KernelFuncs<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>::Cache().At(
|
|
|
|
|
|
|
|
0);
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
|
|
|
|
|
|
@ -967,10 +969,10 @@ void TestKernelVBroadcastTuples() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
|
|
|
|
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
|
|
|
|
TEST(JITKernel, kernel_type) { \
|
|
|
|
TEST(JITKernel, kernel_type) { \
|
|
|
|
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
|
|
|
|
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
|
|
|
|
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
|
|
|
|
TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_CPU_KERNEL(XYZNTuples, kVMul);
|
|
|
|
TEST_CPU_KERNEL(XYZNTuples, kVMul);
|
|
|
@ -1041,4 +1043,18 @@ TEST(JITKernel_key, gru) {
|
|
|
|
EXPECT_TRUE(key2 == key3);
|
|
|
|
EXPECT_TRUE(key2 == key3);
|
|
|
|
EXPECT_TRUE(key3 != key4);
|
|
|
|
EXPECT_TRUE(key3 != key4);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// TODO(TJ): add more test about key and pool
|
|
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, kernel_func) {
|
|
|
|
|
|
|
|
auto f1 =
|
|
|
|
|
|
|
|
jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
|
|
|
|
|
|
|
|
.At(3);
|
|
|
|
|
|
|
|
auto f2 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>,
|
|
|
|
|
|
|
|
CPUPlace>::Cache()[3];
|
|
|
|
|
|
|
|
EXPECT_TRUE(f1 == f2);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f1 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
|
|
|
|
|
|
|
|
.At(3);
|
|
|
|
|
|
|
|
f2 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
|
|
|
|
|
|
|
|
.At(4);
|
|
|
|
|
|
|
|
EXPECT_TRUE(f1 != f2);
|
|
|
|
|
|
|
|
}
|
|
|
|