|
|
|
@ -19,6 +19,7 @@
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernels.h"
|
|
|
|
|
#include "paddle/fluid/platform/cpu_info.h"
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -414,6 +415,59 @@ void TestGRUKernel() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestNCHW16CMulNCKernel() {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
const int n = 3, c = 16 * 4, h = 10, w = 10;
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
int sz = n * c * h * w;
|
|
|
|
|
std::vector<T> x(sz), y(n * c), zref(sz);
|
|
|
|
|
std::vector<T> ztgt(sz), zjit(sz);
|
|
|
|
|
RandomVec<T>(sz, x.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(n * c, y.data(), -2.f, 2.f);
|
|
|
|
|
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
const T* y_data = y.data();
|
|
|
|
|
T* zref_data = zref.data();
|
|
|
|
|
T* ztgt_data = ztgt.data();
|
|
|
|
|
T* zjit_data = zjit.data();
|
|
|
|
|
constexpr int simd_width = ZMM_FLOAT_BLOCK;
|
|
|
|
|
int C = c / simd_width;
|
|
|
|
|
auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
|
|
|
|
|
if (std::is_same<T, float>::value &&
|
|
|
|
|
paddle::platform::MayIUse(paddle::platform::avx512f)) {
|
|
|
|
|
EXPECT_TRUE(jitcode != nullptr);
|
|
|
|
|
}
|
|
|
|
|
for (int ni = 0; ni < n; ni++) {
|
|
|
|
|
for (int ci = 0; ci < C; ci++) {
|
|
|
|
|
auto ptr_x =
|
|
|
|
|
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
|
|
|
|
|
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
|
|
|
|
|
auto ptr_zref =
|
|
|
|
|
zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
|
|
|
|
|
auto ptr_ztgt =
|
|
|
|
|
ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
|
|
|
|
|
|
|
|
|
|
ref(ptr_x, ptr_y, ptr_zref, h, w);
|
|
|
|
|
tgt(ptr_x, ptr_y, ptr_ztgt, h, w);
|
|
|
|
|
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
auto ptr_zjit =
|
|
|
|
|
zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
|
|
|
|
|
jitcode(ptr_x, ptr_y, ptr_zjit, h, w);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ExpectEQ<T>(ztgt_data, zref_data, sz);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
ExpectEQ<T>(zjit_data, zref_data, sz);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// XYZNTuple
|
|
|
|
|
TEST(JITKernel, vmul) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
@ -515,6 +569,14 @@ TEST(JITKernel, gruhtpart2) {
|
|
|
|
|
TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, nchw16cmulnc) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::nchw16cmulnc, float,
|
|
|
|
|
paddle::platform::CPUPlace>();
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::nchw16cmulnc, double,
|
|
|
|
|
paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {
|
|
|
|
|