|
|
|
@ -469,111 +469,111 @@ void TestNCHW16CMulNCKernel() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// XYZNTuple
|
|
|
|
|
TEST(JITKernel, vmul) {
|
|
|
|
|
TEST(JITKernel, kVMul) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVMul, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVMul, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vadd) {
|
|
|
|
|
TEST(JITKernel, kVAdd) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYZNKernel<jit::vadd, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::vadd, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVAdd, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVAdd, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vaddrelu) {
|
|
|
|
|
TEST(JITKernel, kVAddRelu) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYZNKernel<jit::vaddrelu, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::vaddrelu, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVAddRelu, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVAddRelu, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vsub) {
|
|
|
|
|
TEST(JITKernel, kVSub) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYZNKernel<jit::vsub, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVSub, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYZNKernel<jit::kVSub, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// AXYNTuples
|
|
|
|
|
TEST(JITKernel, vscal) {
|
|
|
|
|
TEST(JITKernel, kVScal) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::kVScal, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::kVScal, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vaddbias) {
|
|
|
|
|
TEST(JITKernel, kVAddBias) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::kVAddBias, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::kVAddBias, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// XYNTuples
|
|
|
|
|
TEST(JITKernel, vrelu) {
|
|
|
|
|
TEST(JITKernel, kVRelu) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vrelu, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vrelu, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVRelu, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVRelu, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, videntity) {
|
|
|
|
|
TEST(JITKernel, kVIdentity) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::videntity, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::videntity, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVIdentity, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vexp) {
|
|
|
|
|
TEST(JITKernel, kVExp) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vexp, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vexp, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVExp, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vsigmoid) {
|
|
|
|
|
TEST(JITKernel, kVSigmoid) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vsigmoid, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vsigmoid, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVSigmoid, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVSigmoid, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vtanh) {
|
|
|
|
|
TEST(JITKernel, kVTanh) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestXYNKernel<jit::vtanh, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVTanh, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestXYNKernel<jit::kVTanh, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// LSTM
|
|
|
|
|
TEST(JITKernel, lstmctht) {
|
|
|
|
|
TEST(JITKernel, kLSTMCtHt) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::kLSTMCtHt, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::kLSTMCtHt, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, lstmc1h1) {
|
|
|
|
|
TEST(JITKernel, kLSTMC1H1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::kLSTMC1H1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestLSTMKernel<jit::kLSTMC1H1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GRU
|
|
|
|
|
TEST(JITKernel, gruh1) {
|
|
|
|
|
TEST(JITKernel, kGRUH1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruh1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUH1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUH1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, gruhtpart1) {
|
|
|
|
|
TEST(JITKernel, kGRUHtPart1) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruhtpart1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruhtpart1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUHtPart1, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUHtPart1, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, gruhtpart2) {
|
|
|
|
|
TEST(JITKernel, kGRUHtPart2) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestGRUKernel<jit::gruhtpart2, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUHtPart2, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, nchw16cmulnc) {
|
|
|
|
|
TEST(JITKernel, kNCHW16CMulNC) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::nchw16cmulnc, float,
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
|
|
|
|
|
paddle::platform::CPUPlace>();
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::nchw16cmulnc, double,
|
|
|
|
|
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double,
|
|
|
|
|
paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|