Merge pull request #14958 from tensor-tang/refine/jit

enhance jit
tensor-tang 6 years ago committed by GitHub
commit 693e5e65ce

@ -16,6 +16,7 @@ add_subdirectory(metrics)
add_subdirectory(optimizers)
add_subdirectory(reduce_ops)
add_subdirectory(sequence_ops)
add_subdirectory(jit)
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
@ -64,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)

@ -16,7 +16,7 @@ limitations under the License. */
#include <limits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track;
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
const auto& ker = math::jitkernel::KernelPool::Instance()
.template Get<math::jitkernel::CRFDecodeKernel<T>>(
static_cast<int>(tag_num));
ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) {

@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16;
int C = c / simd_width;
const auto& multiply =
math::jitkernel::KernelPool::Instance()
.template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>(0);
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
multiply(ptr_x, ptr_y, ptr_z, h, w);
}
}
}

@ -15,9 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const math::jitkernel::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation")); \
math::jitkernel::gru_t one_step; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const math::jitkernel::gru_attr_t&>(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const jit::gru_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
auto ComputeH1 = \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart1 = \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart2 = \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const {
@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} else {
one_step.gates = xx_data;
one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr);
ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data;
tstart = 1;
move_step();
@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr);
ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3);
ker->ComputeHtPart2(&one_step, &attr);
ComputeHtPart2(&one_step, &attr);
// save prev
prev_hidden_data = hidden_out_data;
move_step();
@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
for (int i = 0; i < max_bs; ++i) {
one_step.gates = cur_in_data;
one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr);
ComputeH1(&one_step, &attr);
// add offset
cur_in_data += D3;
cur_out_data += D;
@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr);
ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3;
cur_prev_hidden_data += D;
@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr);
ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3;
cur_prev_hidden_data += D;
cur_out_data += D;

@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const math::jitkernel::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
math::jitkernel::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \
jit::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = xx_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr);
ComputeC1H1(&one_step, &attr);
tstart = 1;
// move one step
prev_h_data = h_out_data;
@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data;
one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr);
ComputeCtHt(&one_step, &attr);
// move one step
prev_h_data = h_out_data;
prev_c_data = c_out_data;
@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr);
ComputeC1H1(&one_step, &attr);
cur_in_data += D4;
cur_c_out_data += D;
@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr);
ComputeCtHt(&one_step, &attr);
// move one batch
cur_in_data += D4;

@ -0,0 +1,25 @@
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
# refer must go first
add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
add_subdirectory(gen)
endif()
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper)
endif()

@ -0,0 +1,66 @@
# JIT Kernel
Combine function templates with JIT code generation to produce the kernel functions that are needed.
A kernel here is an operator unit at a smaller granularity than an Operator kernel, and it focuses more on performance on different hardware. Multiple third-party implementations are allowed, and each implementation has its own `UseMe` function that decides under which conditions it may be called.
The functions implemented here can be very fine-grained (for example a vector MUL), or a complex piece of logic such as an LSTM. Complex logic can also be assembled from its own lower-level functions.
Currently only high-performance computation on CPU is supported.
## Directory Structure
```txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
└── operators/
    ├── .../
    └── jit/
        ├── ...
        ├── gen/
        │   └── ...
        ├── more/
        │   ├── ...
        │   ├── mkl/
        │   │   └── ...
        │   ├── mkldnn/
        │   │   └── ...
        │   ├── mix/
        │   │   └── ...
        │   ├── intrinsic/
        │   │   └── ...
        │   └── openblas/
        │       └── ...
        └── refer/
            └── ...
```
The definitions of the base classes live in the root directory, which contains the three subdirectories gen, more, and refer. Each subdirectory holds one or more implementations. Every kernel must have a reference implementation, which serves as the baseline for unit tests; all other implementations are optional.
- gen: code generated with JIT, which depends on the xbyak library. This implementation cares most about performance.
- refer: the reference implementation. Every kernel must have a reference implementation on CPU, and it mainly cares about the correctness of the algorithm logic.
- more: additional implementations, which may use MKL, MKL-DNN, intrinsics, OpenBLAS, etc., or be compositions of the existing kernels.
## Dynamic Dispatch
A `jit::Get` method is provided to look up a kernel by its kernel type. Each implementation declares its own applicable range, and the required kernel function is selected dynamically according to that range and the current conditions.
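For illustration, a minimal usage sketch (the function name and sizes here are only an example; the call pattern follows the operators updated in this PR):

```cpp
#include <vector>

#include "paddle/fluid/operators/jit/kernels.h"  // generated header, see the jit CMakeLists
#include "paddle/fluid/platform/place.h"

namespace jit = paddle::operators::jit;

void VMulExample() {
  const int d = 8;  // attribute: vector length
  std::vector<float> x(d, 1.f), y(d, 2.f), z(d);
  // Get picks the best available implementation for this attribute:
  // jitcode first, then any "more" implementation whose UseMe(attr) holds,
  // and finally the refer implementation as the guaranteed fallback.
  auto vmul = jit::Get<jit::kVMul, jit::XYZNTuples<float>,
                       paddle::platform::CPUPlace>(d);
  vmul(x.data(), y.data(), z.data(), d);  // z[i] = x[i] * y[i]
}
```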
## Testing
- Correctness tests
All implementations are compared against the refer code and must meet the accuracy requirements, for both the float and double data types (a rough sketch of this comparison follows this section).
- Performance tests
Compare the performance of all implementations, and also against the final `jit::Get` method; the implementation returned by `jit::Get` must be the fastest under every condition.
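A rough sketch of the correctness comparison (a hypothetical test, assuming gtest as used by `jit_kernel_test`; the real `test.cc` is not part of this excerpt):

```cpp
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

namespace jit = paddle::operators::jit;

TEST(JITKernel, vmul_matches_refer) {
  const int d = 30;
  std::vector<float> x(d, 1.5f), y(d, 2.f), zref(d), ztgt(d);
  // The refer result is the baseline; the dispatched implementation must match it.
  auto refer = jit::GetRefer<jit::kVMul, jit::XYZNTuples<float>>();
  auto tgt = jit::Get<jit::kVMul, jit::XYZNTuples<float>,
                      paddle::platform::CPUPlace>(d);
  refer(x.data(), y.data(), zref.data(), d);
  tgt(x.data(), y.data(), ztgt.data(), d);
  for (int i = 0; i < d; ++i) {
    EXPECT_NEAR(ztgt[i], zref[i], 1e-5f);
  }
}
```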
# How to Add a New Kernel
- Add `your_key` to `KernelType`.
- Implement the reference logic. This is mandatory, must run on the CPU, and must not depend on any third-party library (a minimal sketch follows this list). After implementing it, add `USE_JITKERNEL_REFER(your_key)` to `refer/CMakeLists.txt` to use the kernel.
- (optional) Implement more algorithms under the `more` directory; these may depend on third-party libraries such as MKL, intrinsics, or MKL-DNN.
- (optional) Implement Xbyak-based generated code under the `gen` directory. A JitCode needs its own `JitCodeCreator`, registered on the same `KernelType` as the refer implementation.
- Add a new `KernelTuples` if necessary, following `XYZNTuples`; a newly added Attr type needs a specialization of the `JitCodeKey` method.
- Add unit tests in `test.cc`, covering at least the `float` and `double` data types; support additional data types, such as `int8`-related functions, where necessary.
- Add the corresponding performance comparison in `benchmark.cc`. For the same kernel, compare all implementations and make sure the one obtained via `jit::Get` is always the fastest.
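For the mandatory reference step, a minimal sketch (a hypothetical function; the actual refer sources are not part of this excerpt):

```cpp
// A plain CPU loop with no third-party dependencies; this is the style of
// implementation that refer/ holds and that every other implementation is
// checked against.
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}
```

After that, adding `USE_JITKERNEL_REFER(your_key)` to `refer/CMakeLists.txt` makes the kernel available, as described above.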
# Advantages
- A unified Get method with a simple interface.
- The same logic can have multiple implementations that depend on different third-party libraries, without affecting each other.
- A clear directory structure, avoiding the poor readability caused by piling multiple macro definitions into a single file.
- Easy to optimize: a particular attribute can be optimized directly without affecting performance under other attributes.
- Multiple platforms are supported, including Linux, Mac, and Windows; at the very least every platform works correctly, and platform-specific optimizations can be added later. The framework layer can use a unified interface without caring about the underlying implementation.

@ -0,0 +1,231 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
std::mt19937 rng(seed);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i <= FLAGS_max_size; ++i) {
s.push_back(i);
}
return s;
}
template <typename KernelTuples, typename... Args>
struct BenchFunc {
// return this function avg time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...);
}
auto end = paddle::platform::PosixInNsec() / 1e-3;
return static_cast<double>(end - start) / FLAGS_repeat;
}
};
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d), z(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
z.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
use_peephole);
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
std::vector<T> x(3 * d), ht_1(d), ht(d);
RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
const T* ht_1_data = ht_1.data();
T* x_data = x.data();
T* ht_data = ht.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
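// e.g. (binary name per the jit CMakeLists; flag values are only examples):
//   ./jit_kernel_benchmark --burning=10 --repeat=3000 --max_size=1000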
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
using T = float;
using PlaceType = paddle::platform::CPUPlace;
// xyzn
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
// axyn
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
// lstm and peephole
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
// gru functions
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
}

@ -0,0 +1,28 @@
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
USE_JITKERNEL_GEN(kVTanh)
USE_JITKERNEL_GEN(kLSTMCtHt)
USE_JITKERNEL_GEN(kLSTMC1H1)
USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)

@ -0,0 +1,135 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
void VActJitCode::genCode() {
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
act<ymm_t>(ymm_dst, ymm_src, type_);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]);
} else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]);
}
act<xmm_t>(xmm_dst, xmm_src, type_);
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
size_t VReluCreator::CodeSize(const int& d) const {
return 96 /* init size */ +
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
8 /* average bytes for each instruction */;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VExpCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8;
}
size_t VSigmoidCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8;
}
size_t VTanhCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8;
}
#undef DECLARE_ACT_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);

@ -0,0 +1,186 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void VXXJitCode::genCode() {
// No need to push the stack, and no need to save the avx512 registers if avx512 is not used
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
block = 1;
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
switch (type_) {
case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
void NCHW16CMulNCJitCode::genCode() {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx);
xor_(rax, rax);
xor_(r10, r10);
vmovups(zmm3, ptr[rsi]);
L("h_loop");
xor_(rbx, rbx);
L("w_loop");
vmovups(zmm2, ptr[rdi + rax]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx);
ret();
}
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
}
};
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);

@ -0,0 +1,117 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::MUL) {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
void genCode() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE
// nChw16c = nChw16c .* NC
class NCHW16CMulNCJitCode : public JitCode {
public:
DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
void genCode() override;
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,116 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void GRUJitCode::genCode() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);
#undef DECLARE_GRU_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);

@ -0,0 +1,113 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActFunc {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
this->genCode();
}
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
void genCode() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);
#undef DECLARE_GRU_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,126 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX);
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
MUL = 0,
ADD,
SUB,
RELU,
EXP,
SIGMOID,
TANH,
IDENTITY
} operand_type;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
virtual void preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
virtual void postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void L(const char* label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
int scale = 0;
// Learn from https://github.com/intel/mkl-dnn
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,142 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void LSTMJitCode::genCode() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
#define DECLARE_LSTM_CREATOR(name) \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const lstm_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);
#undef DECLARE_LSTM_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);

@ -0,0 +1,118 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class LSTMJitCode : public VActFunc {
public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr),
num_(attr.d),
compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
this->genCode();
}
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
void genCode() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \
class name##JitCode : public LSTMJitCode { \
public: \
explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \
};
DECLARE_LSTM_JITCODE(LSTMCtHt, false);
DECLARE_LSTM_JITCODE(LSTMC1H1, true);
#undef DECLARE_LSTM_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,43 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace jit {
// refer does not need UseMe; it is always the last fallback.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char*>(code), this->getSize());
fout.close();
}
}
}
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,70 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace jit {
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename Func>
Func getCode() {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
return reinterpret_cast<Func>(const_cast<unsigned char*>(code));
}
protected:
void dumpCode(const unsigned char* code) const;
};
// A Creator is used to create the JitCode and save it in the pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,76 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
#define ONE_CASE(key) \
case key: \
return #key
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
ONE_CASE(kLSTMC1H1);
ONE_CASE(kGRUH1);
ONE_CASE(kGRUHtPart1);
ONE_CASE(kGRUHtPart2);
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
std::string lower = act;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
if (lower == "relu" || lower == "vrelu") {
return kVRelu;
} else if (lower == "identity" || lower == "videntity" || lower == "") {
return kVIdentity;
} else if (lower == "exp" || lower == "vexp") {
return kVExp;
} else if (lower == "sigmoid" || lower == "vsigmoid") {
return kVSigmoid;
} else if (lower == "tanh" || lower == "vtanh") {
return kVTanh;
}
PADDLE_THROW("Not support type: %s, or forget to add this case", act);
return kNone;
}
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,140 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
// Creators do not depend on the attribute, so KernelKey alone can serve as the map key.
KernelKey kkey(KT, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
// Reference kernels do not depend on the attribute; the tuple type is only needed for the cast below.
// Reference kernels are always registered on CPUPlace.
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// Fall back to the reference implementation, which is always available on CPUPlace.
return GetRefer<KT, KernelTuples>();
}
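// Usage sketch (illustrative): fetch an element-wise multiply kernel for float
// data of length n and call it; x, y and z are assumed to be float* buffers of
// length n.
//   auto vmul = Get<kVMul, XYZNTuples<float>, platform::CPUPlace>(n);
//   vmul(x, y, z, n);  // tries JIT code, then other CPU impls, then the refer kernel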
const char* to_string(KernelType kt);
KernelType to_kerneltype(const std::string& act);
inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "],act_cell["
<< to_string(attr.act_cell) << "],use_peephole["
<< (attr.use_peephole ? "True" : "False") << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
} // namespace jit
} // namespace operators
} // namespace paddle

@ -0,0 +1,172 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
namespace jit {
typedef enum {
kNone = 0,
kVMul = 1,
kVAdd = 2,
kVAddRelu,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
kCRFDecoding,
kLayerNorm,
kNCHW16CMulNC,
} KernelType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
template <typename T>
struct XYNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used when peephole connections are enabled */
const void* wp{nullptr}; // W_ic, W_fc, W_oc
void* checked{nullptr}; // size: 2 * d
} lstm_t;
typedef struct {
void* gates; // gates: {x_update, x_reset; x_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
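// Construction sketch (illustrative): attributes of width 64 built with the
// constructors above, with peephole connections enabled for the LSTM case.
//   lstm_attr_t attr(64, kVSigmoid, kVTanh, kVTanh, /*use_peephole=*/true);
//   gru_attr_t gru(64, kVSigmoid, kVTanh);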
template <typename T>
struct LSTMTuples {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
template <typename T>
struct GRUTuples {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
const float, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// Non-template base class so kernels of any tuple type can be stored in one pool.
class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
protected:
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
};
} // namespace jit
} // namespace operators
} // namespace paddle
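A sketch of a platform-specific kernel built on KernelMore; MyVMulImpl is a hypothetical free function matching XYZNTuples<float>::func_type, and the size threshold is an assumption:

class MyVMulKernel : public KernelMore<XYZNTuples<float>> {
 public:
  MyVMulKernel() { this->func = MyVMulImpl; }  // hypothetical implementation
  bool UseMe(const int& d) const override { return d >= 16; }
  const char* ImplType() const override { return "MyImpl"; }
};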

@ -0,0 +1,47 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
return d;
}
constexpr int act_type_shift = 3;  // support up to 2^3 activation types
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
}
} // namespace jit
} // namespace operators
} // namespace paddle
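A worked example of the GRU key packing above, assuming kVSigmoid and kVTanh take the enum values 10 and 11 from kernel_base.h:

gru_attr_t attr(100, kVSigmoid, kVTanh);
size_t key = JitCodeKey<gru_attr_t>(attr);
// key = (100 << 6) + 10 + (11 << 3) = 6400 + 10 + 88 = 6498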

@ -0,0 +1,53 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
struct KernelKey {
struct Hash {
size_t operator()(const KernelKey& key) const {
int place = key.place_.which(); // less than 2^8
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
std::hash<int> hasher;
return hasher(place + type);
}
};
KernelType type_;
platform::Place place_;
KernelKey(KernelType type, platform::Place place)
: type_(type), place_(place) {}
size_t hash_key() const { return Hash()(*this); }
bool operator==(const KernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) &&
type_ == o.type_;
}
bool operator!=(const KernelKey& o) const { return !(*this == o); }
};
// Every JitCode should provide a method to compute its cache key from the attribute.
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
} // namespace paddle
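A small sketch of how the key behaves: two keys with the same kernel type and the same place class compare equal and hash identically, so all CPU implementations of one kernel type share a bucket in the pools.

KernelKey a(kVMul, platform::CPUPlace());
KernelKey b(kVMul, platform::CPUPlace());
// a == b holds, and a.hash_key() == b.hash_key()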
