|
|
|
@ -52,9 +52,10 @@ std::vector<int> TestSizes() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the average time of one call to the target function.
|
|
|
|
|
template <typename T, typename Func>
|
|
|
|
|
double BenchTartgetFunc(const Func tgt, const std::vector<T>& x,
|
|
|
|
|
const std::vector<T>& y, std::vector<T>& z) { // NOLINT
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
double BenchTartgetFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const std::vector<T>& x, const std::vector<T>& y,
|
|
|
|
|
std::vector<T>& z) { // NOLINT
|
|
|
|
|
const T* x_data = x.data();
|
|
|
|
|
const T* y_data = y.data();
|
|
|
|
|
const int d = z.size();
|
|
|
|
@ -71,40 +72,25 @@ double BenchTartgetFunc(const Func tgt, const std::vector<T>& x,
|
|
|
|
|
return (end - start) / FLAGS_repeat;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Benchmark all jit kernels including jitcode, mkl and refer.
|
|
|
|
|
// To use this tool, run command: ./benchmark [options...]
|
|
|
|
|
// Options:
|
|
|
|
|
// --burning: the number of burning iterations run before timing starts
|
|
|
|
|
// --repeat: the number of times each measurement is repeated
|
|
|
|
|
// --max_size: the maximum size that will be tested
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
|
|
|
|
google::InitGoogleLogging(argv[0]);
|
|
|
|
|
using T = float;
|
|
|
|
|
using PlaceType = paddle::platform::CPUPlace;
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void BenchXYZNKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
const auto KT = jit::vmul;
|
|
|
|
|
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
|
|
|
|
|
<< " times.";
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
// for (kernels type) { // TODO(TJ): more jit::KernelType
|
|
|
|
|
std::vector<std::pair<std::string, double>> infos;
|
|
|
|
|
std::vector<T> x(d), y(d), z(d);
|
|
|
|
|
RandomVec<T>(d, x.data());
|
|
|
|
|
RandomVec<T>(d, y.data());
|
|
|
|
|
// refer
|
|
|
|
|
auto refer = jit::GetRefer<KT, jit::VMulTuples<T>>();
|
|
|
|
|
auto refer = jit::GetRefer<KT, jit::XYZNTuples<T>>();
|
|
|
|
|
if (refer) {
|
|
|
|
|
auto res =
|
|
|
|
|
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(refer, x, y, z);
|
|
|
|
|
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(refer, x, y, z);
|
|
|
|
|
infos.push_back(std::make_pair("Refer", res));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
auto res =
|
|
|
|
|
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, z);
|
|
|
|
|
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, z);
|
|
|
|
|
infos.push_back(std::make_pair("JitCode", res));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -115,32 +101,50 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
|
|
|
|
|
auto i = dynamic_cast<const jit::KernelImpl<jit::XYZNTuples<T>>*>(
|
|
|
|
|
impl.get());
|
|
|
|
|
if (i && i->UseMe(d)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
auto res =
|
|
|
|
|
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(more, x, y, z);
|
|
|
|
|
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, z);
|
|
|
|
|
infos.push_back(std::make_pair("More", res));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
|
|
|
|
|
auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
|
|
|
|
|
if (!tgt) {
|
|
|
|
|
LOG(ERROR) << "Target can not be empty!";
|
|
|
|
|
}
|
|
|
|
|
auto res = BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, z);
|
|
|
|
|
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, z);
|
|
|
|
|
infos.push_back(std::make_pair("Target", res));
|
|
|
|
|
|
|
|
|
|
// print
|
|
|
|
|
std::ostringstream loginfos;
|
|
|
|
|
loginfos << "Kernel Type: " << KT << ", size " << d << ": ";
|
|
|
|
|
loginfos << "Kernel Type: " << jit::to_string(KT) << ", size " << d << ": ";
|
|
|
|
|
for (auto pair : infos) {
|
|
|
|
|
loginfos << pair.first << " takes " << pair.second << " us; ";
|
|
|
|
|
}
|
|
|
|
|
LOG(INFO) << loginfos.str();
|
|
|
|
|
// }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Benchmark all jit kernels including jitcode, mkl and refer.
|
|
|
|
|
// To use this tool, run command: ./benchmark [options...]
|
|
|
|
|
// Options:
|
|
|
|
|
// --burning: the number of burning iterations run before timing starts
|
|
|
|
|
// --repeat: the number of times each measurement is repeated
|
|
|
|
|
// --max_size: the maximum size that will be tested
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
|
// Parse the command-line flags first so FLAGS_burning / FLAGS_repeat hold
// the user-supplied values, then initialize glog with the program name.
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
|
|
|
|
google::InitGoogleLogging(argv[0]);
|
|
|
|
|
// Log the benchmark configuration once, before any kernel is benchmarked.
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
|
|
|
|
|
<< " times.";
|
|
|
|
|
// Element type and place type shared by every benchmark invoked below.
using T = float;
|
|
|
|
|
using PlaceType = paddle::platform::CPUPlace;
|
|
|
|
|
// Short alias for the jit kernel namespace.
namespace jit = paddle::operators::jit;
|
|
|
|
|
// Run BenchXYZNKernel once per kernel type, all with the same T and
// PlaceType.
BenchXYZNKernel<jit::vmul, T, PlaceType>();
|
|
|
|
|
BenchXYZNKernel<jit::vadd, T, PlaceType>();
|
|
|
|
|
BenchXYZNKernel<jit::vaddrelu, T, PlaceType>();
|
|
|
|
|
BenchXYZNKernel<jit::vsub, T, PlaceType>();
|
|
|
|
|
}
|
|
|
|
|