|
|
|
@ -272,6 +272,98 @@ void BenchXYNKernel() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return this function avg time
|
|
|
|
|
template <typename T, typename KernelTuples>
|
|
|
|
|
double BenchLSTMFunc(const typename KernelTuples::func_type tgt,
|
|
|
|
|
const paddle::operators::jit::lstm_attr_t* attr,
|
|
|
|
|
paddle::operators::jit::lstm_t* step) {
|
|
|
|
|
for (int i = 0; i < FLAGS_burning; ++i) {
|
|
|
|
|
tgt(step, attr);
|
|
|
|
|
}
|
|
|
|
|
auto start = GetCurrentUS();
|
|
|
|
|
for (int i = 0; i < FLAGS_repeat; ++i) {
|
|
|
|
|
tgt(step, attr);
|
|
|
|
|
}
|
|
|
|
|
auto end = GetCurrentUS();
|
|
|
|
|
return (end - start) / FLAGS_repeat;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
|
|
|
|
void BenchLSTMKernel() {
|
|
|
|
|
namespace jit = paddle::operators::jit;
|
|
|
|
|
for (bool use_peephole : {true, false}) {
|
|
|
|
|
for (int d : TestSizes()) {
|
|
|
|
|
const jit::lstm_attr_t attr(d, jit::vsigmoid, jit::vtanh, jit::vtanh,
|
|
|
|
|
use_peephole);
|
|
|
|
|
std::vector<std::pair<std::string, double>> infos;
|
|
|
|
|
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
|
|
|
|
|
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
|
|
|
|
|
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
|
|
|
|
|
const T* ct_1_data = ct_1.data();
|
|
|
|
|
const T* wp_data = wp.data();
|
|
|
|
|
T* x_data = x.data();
|
|
|
|
|
T* checked_data = checked.data();
|
|
|
|
|
T* ct_data = ct.data();
|
|
|
|
|
T* ht_data = ht.data();
|
|
|
|
|
jit::lstm_t step;
|
|
|
|
|
step.gates = x_data;
|
|
|
|
|
step.ct_1 = ct_1_data;
|
|
|
|
|
step.ct = ct_data;
|
|
|
|
|
step.ht = ht_data;
|
|
|
|
|
if (use_peephole) {
|
|
|
|
|
step.wp = wp_data;
|
|
|
|
|
step.checked = checked_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// test refer
|
|
|
|
|
auto refer = jit::GetRefer<KT, jit::LSTMTuples<T>>();
|
|
|
|
|
if (refer) {
|
|
|
|
|
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(refer, &attr, &step);
|
|
|
|
|
infos.push_back(std::make_pair("Refer", res));
|
|
|
|
|
}
|
|
|
|
|
// test jitcode
|
|
|
|
|
auto jitcode = jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
|
|
|
|
|
if (jitcode) {
|
|
|
|
|
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, &attr, &step);
|
|
|
|
|
infos.push_back(std::make_pair("JitCode", res));
|
|
|
|
|
}
|
|
|
|
|
// test all impls in more
|
|
|
|
|
jit::KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto& pool = jit::KernelPool().Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
|
|
|
|
|
impl.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto more = i->GetFunc();
|
|
|
|
|
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(more, &attr, &step);
|
|
|
|
|
infos.push_back(std::make_pair("More", res));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Test result from Get function
|
|
|
|
|
auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
|
|
|
|
|
if (!tgt) {
|
|
|
|
|
LOG(ERROR) << "Target can not be empty!";
|
|
|
|
|
}
|
|
|
|
|
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(tgt, &attr, &step);
|
|
|
|
|
infos.push_back(std::make_pair("Target", res));
|
|
|
|
|
// print
|
|
|
|
|
std::ostringstream loginfos;
|
|
|
|
|
loginfos << "Kernel Type: " << jit::to_string(KT)
|
|
|
|
|
<< ", Sigmoid,Tanh,Tanh, " << (use_peephole ? "Peephole_" : "")
|
|
|
|
|
<< " size " << d << ": ";
|
|
|
|
|
for (auto pair : infos) {
|
|
|
|
|
loginfos << pair.first << " takes " << pair.second << " us; ";
|
|
|
|
|
}
|
|
|
|
|
LOG(INFO) << loginfos.str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Benchmark all jit kernels including jitcode, mkl and refer.
|
|
|
|
|
// To use this tool, run command: ./benchmark [options...]
|
|
|
|
|
// Options:
|
|
|
|
@ -294,9 +386,14 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
BenchAXYNKernel<jit::vscal, T, PlaceType>();
|
|
|
|
|
BenchAXYNKernel<jit::vaddbias, T, PlaceType>();
|
|
|
|
|
|
|
|
|
|
// act
|
|
|
|
|
BenchXYNKernel<jit::vrelu, T, PlaceType>();
|
|
|
|
|
BenchXYNKernel<jit::videntity, T, PlaceType>();
|
|
|
|
|
BenchXYNKernel<jit::vexp, T, PlaceType>();
|
|
|
|
|
BenchXYNKernel<jit::vsigmoid, T, PlaceType>();
|
|
|
|
|
BenchXYNKernel<jit::vtanh, T, PlaceType>();
|
|
|
|
|
|
|
|
|
|
// lstm and peephole
|
|
|
|
|
BenchLSTMKernel<jit::lstmctht, T, PlaceType>();
|
|
|
|
|
BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>();
|
|
|
|
|
}
|
|
|
|
|