|
|
|
@ -12,7 +12,6 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <cstring> // for memcpy
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -59,9 +58,9 @@ std::vector<int> TestSizes() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
void TestTartgetFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& y,
|
|
|
|
|
const std::vector<T>& zref) {
|
|
|
|
|
void TestXYZNFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& y,
|
|
|
|
|
const std::vector<T>& zref) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(zref.size(), x.size());
|
|
|
|
|
EXPECT_EQ(zref.size(), y.size());
|
|
|
|
@ -88,9 +87,8 @@ void TestTartgetFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestXYZNKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT)
|
|
|
|
|
<< ", size: " << d;
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
|
|
|
|
@ -119,7 +117,7 @@ void TestXYZNKernel() {
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
VLOG(10) << "Test Jitcode Kernel, size: " << d;
|
|
|
|
|
TestTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
|
|
|
|
|
TestXYZNFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test all impls in more
|
|
|
|
@ -134,14 +132,14 @@ void TestXYZNKernel() {
|
|
|
|
|
if (i && i->UseMe(d)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
VLOG(10) << "Test More Kernel, size: " << d;
|
|
|
|
|
TestTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
|
|
|
|
|
TestXYZNFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
VLOG(10) << "Test Get function, size: " << d;
|
|
|
|
|
auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
|
|
|
|
|
TestTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
|
|
|
|
|
TestXYZNFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -169,4 +167,89 @@ TEST(JITKernel, vsub) {
|
|
|
|
|
TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {}
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
void TestAXYNFunc(const typename KernelTuples::func_type tgt, const T a,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& yref) {
|
|
|
|
|
EXPECT_TRUE(tgt != nullptr);
|
|
|
|
|
EXPECT_EQ(yref.size(), x.size());
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
const T* yref_data = yref.data();
|
|
|
|
|
const int d = yref.size();
|
|
|
|
|
std::vector<T> ytgt(d);
|
|
|
|
|
T* ytgt_data = ytgt.data();
|
|
|
|
|
// test normal
|
|
|
|
|
tgt(&a, x_data, ytgt_data, d);
|
|
|
|
|
ExpectEQ<T>(ytgt_data, yref_data, d);
|
|
|
|
|
// test inplace x
|
|
|
|
|
std::copy(x.begin(), x.end(), ytgt.begin());
|
|
|
|
|
tgt(&a, ytgt_data, ytgt_data, d);
|
|
|
|
|
ExpectEQ<T>(ytgt_data, yref_data, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void TestAXYNKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
|
|
|
|
|
EXPECT_TRUE(ref != nullptr);
|
|
|
|
|
|
|
|
|
|
const T a = static_cast<T>(3);
|
|
|
|
|
std::vector<T> x(d), yref(d);
|
|
|
|
|
std::vector<T> xinp(d); // inplace test
|
|
|
|
|
RandomVec<T>(d, x.data());
|
|
|
|
|
std::copy(x.begin(), x.end(), xinp.begin());
|
|
|
|
|
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
T* yref_data = yref.data();
|
|
|
|
|
T* xinp_data = xinp.data();
|
|
|
|
|
// test refer code inplace
|
|
|
|
|
ref(&a, x_data, yref_data, d);
|
|
|
|
|
ref(&a, xinp_data, xinp_data, d);
|
|
|
|
|
ExpectEQ<T>(xinp_data, yref_data, d);
|
|
|
|
|
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::AXYNTuples<T>, PlaceType>(d);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
VLOG(10) << "Test Jitcode Kernel, size: " << d;
|
|
|
|
|
TestAXYNFunc<T, jit::AXYNTuples<T>>(jitcode, a, x, yref);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test all impls in more
|
|
|
|
|
jit::KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto& pool = jit::KernelPool().Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const jit::KernelImpl<jit::AXYNTuples<T>>*>(
|
|
|
|
|
impl.get());
|
|
|
|
|
if (i && i->UseMe(d)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
VLOG(10) << "Test More Kernel, size: " << d;
|
|
|
|
|
TestAXYNFunc<T, jit::AXYNTuples<T>>(more, a, x, yref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
VLOG(10) << "Test Get function, size: " << d;
|
|
|
|
|
auto tgt = jit::Get<KT, jit::AXYNTuples<T>, PlaceType>(d);
|
|
|
|
|
TestAXYNFunc<T, jit::AXYNTuples<T>>(tgt, a, x, yref);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vscal) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, vaddbias) {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
|
|
|
|
|
TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(JITKernel, pool) {
|
|
|
|
|
// TODO(TJ): add some test
|
|
|
|
|
}
|
|
|
|
|