|
|
|
@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Scalar reference for the fused add + ReLU operation: for each of the first
// n elements, writes x[i] + y[i] to z[i] when that sum is greater than zero,
// and writes zero otherwise.
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
  for (int idx = 0; idx < n; ++idx) {
    const float sum = x[idx] + y[idx];
    if (sum > 0) {
      z[idx] = sum;
    } else {
      // Anything that is not strictly positive is clamped to zero.
      z[idx] = 0;
    }
  }
}
|
|
|
|
|
void vaddrelu_better(
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
|
|
|
|
|
const std::shared_ptr<
|
|
|
|
|
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
|
|
|
|
|
const float* x, const float* y, float* z) {
|
|
|
|
|
vadd->Compute(x, y, z);
|
|
|
|
|
vrelu->Compute(z, z);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Validates the fused VAddReluKernel (and the VAddKernel + VReluKernel
// composition) against the scalar reference vaddrelu_ref for several vector
// lengths, and logs the average per-call time of each of the three paths.
TEST(JitKernel, vaddrelu) {
  namespace jit = paddle::operators::math::jitkernel;
  // NOTE(review): the sizes sit on and around multiples of 8/16, presumably
  // to exercise vector-width boundaries in the kernels -- confirm.
  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
    std::vector<float> x(d), y(d);
    // Each path gets its own output buffer so that no path can overwrite the
    // reference result before it is compared.
    std::vector<float> zref(d), ztgt(d), zbetter(d);
    RandomVec<float>(d, x.data());
    RandomVec<float>(d, y.data());
    const auto& ker =
        jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
    const auto& vadd =
        jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
    const auto& vrelu =
        jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
    const float* x_data = x.data();
    const float* y_data = y.data();
    float* ztgt_data = ztgt.data();
    float* zref_data = zref.data();
    float* zbetter_data = zbetter.data();

    // Scalar reference: must be the add+relu reference, not the plain add.
    auto trefs = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      vaddrelu_ref(d, x_data, y_data, zref_data);
    }
    auto trefe = GetCurrentUS();

    // Composition of the separate add and relu kernels.
    auto tmkls = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      vaddrelu_better(vadd, vrelu, x_data, y_data, zbetter_data);
    }
    auto tmkle = GetCurrentUS();

    // Fused kernel under test.
    auto ttgts = GetCurrentUS();
    for (int i = 0; i < repeat; ++i) {
      ker->Compute(x_data, y_data, ztgt_data);
    }
    auto ttgte = GetCurrentUS();

    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
            << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
            << "tgt takes: " << (ttgte - ttgts) / repeat;
    for (int i = 0; i < d; ++i) {
      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
      EXPECT_NEAR(zbetter_data[i], zref_data[i], 1e-3);
    }
  }
}
|
|
|
|
|
|
|
|
|
|
TEST(JitKernel, pool) {
|
|
|
|
|
namespace jit = paddle::operators::math::jitkernel;
|
|
|
|
|
const int frame_size = 4;
|
|
|
|
|