simplify the jitkernel templates and tests

test=develop
revert-16045-imperative_remove_desc
tensor-tang 6 years ago
parent 802f362ac4
commit 14a764c930

@ -82,9 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>, auto ker =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
.At(tag_num); .At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;

@ -110,10 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
auto multiply = auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache() .At(0);
.At(0);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {

@ -53,8 +53,7 @@ struct EmbeddingVSumFunctor {
for (size_t i = 0; i != ids_lod.size() - 1; ++i) { for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i]; attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = auto emb_seqpool =
jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>, jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(attr); .At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width, emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr); &attr);
@ -138,8 +137,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
const T *d_output_data = d_output->data<T>(); const T *d_output_data = d_output->data<T>();
auto vbroadcast = auto vbroadcast =
jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>, jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(out_width); .At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);

@ -182,32 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const jit::gru_attr_t attr( \ const jit::gru_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \ jit::gru_t one_step; \
auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \ auto ComputeH1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr); \ attr); \
auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \ auto ComputeHtPart1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \ .At(attr); \
auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \ auto ComputeHtPart2 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \ .At(attr); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {

@ -235,34 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const jit::lstm_attr_t attr( \ const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \ jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \ use_peepholes); \
jit::lstm_t one_step; \ jit::lstm_t one_step; \
one_step.wp = wp_data; \ one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \ one_step.checked = checked_cell_data; \
auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \ auto ComputeC1H1 = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr); \ attr); \
auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \ auto ComputeCtHt = \
platform::CPUPlace>::Cache() \ jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
.At(attr) attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \

@ -81,12 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
template <typename T> template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) { const jit::matmul_attr_t& attr) {
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>, auto matmul =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>, auto addbias_relu =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.n); attr.n);
matmul(x, w, y, &attr); matmul(x, w, y, &attr);
T* dst = y; T* dst = y;
for (int i = 0; i < attr.m; ++i) { for (int i = 0; i < attr.m; ++i) {

@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt; attr.type = jit::SeqPoolType::kSqrt;
} }
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>, auto seqpool =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w; size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {

@ -93,24 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
attr.n = y_dims[1]; attr.n = y_dims[1];
int o_numel = attr.m * attr.n; int o_numel = attr.m * attr.n;
auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_x =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.m * attr.k); attr.m * attr.k);
auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_y =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr.k * attr.n); attr.k * attr.n);
auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>, auto vsquare_xy =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>, auto vsub =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>, auto vscal =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
.At(o_numel); o_numel);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>, auto matmul =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();

File diff suppressed because it is too large Load Diff

@ -19,6 +19,8 @@ extern "C" {
} }
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
@ -30,22 +32,22 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value && std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value, std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuples::func_type; using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance(); auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>(); return codes.AllKernels().at(key)->template getCode<Func>();
} }
// the creator does not depend on attr, so KernelKey can be used as the key // the creator does not depend on attr, so KernelKey can be used as the key
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>) // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey); auto iter = creator_map.find(kkey);
@ -66,27 +68,27 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr; return nullptr;
} }
// Fallback overload of GetJitCode. The std::enable_if condition selects it
// when the tuple's data_type is not float or PlaceType is not
// platform::CPUPlace. It ignores attr and always returns nullptr.
// (In this view each line shows the old text followed by the new text.)
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value || !std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr; return nullptr;
} }
// Refer code do not related with attr, which is just for cast // Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples> template <typename KernelTuple>
inline typename KernelTuples::func_type GetRefer() { inline typename KernelTuple::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels(); auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(), PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function."); "Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) { if (i) {
return i->GetFunc(); return i->GetFunc();
} }
@ -94,23 +96,22 @@ inline typename KernelTuples::func_type GetRefer() {
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename PlaceType = platform::CPUPlace> typename KernelTuple::func_type Get(
typename KernelTuples::func_type Get( const typename KernelTuple::attr_type& attr) {
const typename KernelTuples::attr_type& attr) { auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) { if (jitfunc) {
return jitfunc; return jitfunc;
} }
// pool: (KernelKey(type, place), vector<KernelPtr>) // pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels(); auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->UseMe(attr)) {
return i->GetFunc(); return i->GetFunc();
} }
@ -118,48 +119,50 @@ typename KernelTuples::func_type Get(
} }
// The last implementation should be the reference function on CPUPlace. // The last implementation should be the reference function on CPUPlace.
return GetRefer<KT, KernelTuples>(); return GetRefer<KernelTuple>();
} }
// KernelFuncs caches kernel function pointers in an unordered_map keyed by a
// 64-bit hash of the kernel attribute. Callers obtain the cache via Cache()
// and look up a function with At(attr) or operator[](attr).
// (In this view each line shows the old text followed by the new text; the
// change drops the separate KernelType template parameter.)
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
class KernelFuncs { class KernelFuncs {
public: public:
KernelFuncs() = default; KernelFuncs() = default;
// Returns the cache object for this template instantiation. The object is a
// function-local static declared thread_local.
static KernelFuncs& Cache() { static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache; static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache; return g_func_cache;
} }
// the exposed interface to use // the exposed interface to use
// Returns the function cached for attr; on a miss, resolves it through Get()
// and stores it under the same key before returning it.
typename KernelTuples::func_type At( typename KernelTuple::func_type At(
const typename KernelTuples::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
// The key is the XXH64 hash (seed 0) of the sizeof(attr_type) raw bytes of attr.
// NOTE(review): if attr_type contains padding bytes, equal attrs might hash
// to different keys - confirm against the attr struct definitions.
// XXH64: 13.8 GB/s // XXH64: 13.8 GB/s
int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0); // TODO(TJ): change me, maybe not all attr change need one key, should be
// attrkey
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
if (Has(key)) { if (Has(key)) {
return funcs_.at(key); return funcs_.at(key);
} }
// If do not have this attr in cache, // If do not have this attr in cache,
// then could run some runtime benchmark of this attr and save the best one. // then could run some runtime benchmark of this attr and save the best one.
// Here just get the offline benchmarked best one. // Here just get the offline benchmarked best one.
auto func = Get<KT, KernelTuples, PlaceType>(attr); auto func = Get<KernelTuple, PlaceType>(attr);
Insert(key, func); Insert(key, func);
return func; return func;
} }
// Convenience alias: forwards to At(attr).
typename KernelTuples::func_type operator[]( typename KernelTuple::func_type operator[](
const typename KernelTuples::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
return At(attr); return At(attr);
} }
protected: protected:
// True when a function is already stored under key.
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); } bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
// Stores func under key; emplace leaves an existing entry untouched.
void Insert(int64_t key, typename KernelTuples::func_type func) { void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func); funcs_.emplace(key, func);
} }
private: private:
// hash of attr -> resolved kernel function pointer
std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_; std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs); DISABLE_COPY_AND_ASSIGN(KernelFuncs);
}; };

@ -62,26 +62,55 @@ typedef enum {
kSqrt, kSqrt,
} SeqPoolType; } SeqPoolType;
// x, y, z, n
// Type bundle for kernels taking two const T* inputs, one T* and an int:
// data_type is the element type, attr_type (int) is the attribute used to
// select a kernel, func_type is the kernel's function-pointer signature.
// (Old name XYZNTuples on the left, new name XYZNTuple on the right.)
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int); typedef void (*func_type)(const T*, const T*, T*, int);
}; };
// a, x, y, n
// Same typedefs as the XYZN tuple (inherited unchanged); only the argument
// naming in the comment above differs.
template <typename T> template <typename T>
struct AXYNTuples : public XYZNTuples<T> {}; struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
// Type bundle for kernels taking one const T* input, one T* and an int:
// data_type is the element type, attr_type is int, func_type is the
// function-pointer signature.
template <typename T> template <typename T>
struct XYNTuples { struct XYNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int); typedef void (*func_type)(const T*, T*, int);
}; };
// x, return and int // x, returned value, n
// Inherits the XYN typedefs unchanged; per the comment above, the T* argument
// carries the returned value.
template <typename T> template <typename T>
struct XRNTuples : public XYNTuples<T> {}; struct XRNTuple : public XYNTuple<T> {};
// DECLARE_KERNELTUPLE(kernel_tuple, type) defines a template struct named
// <type>Tuple that derives from kernel_tuple<T> and adds a static constexpr
// member kernel_type equal to the enumerator k<type>. These lines exist only
// in the new column of the diff.
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Each tuple should correspond to the KernelType of the same name
// (e.g. VMulTuple carries kernel_type = kVMul).
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
typedef struct { typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh void* gates; // gates: x_ch, x_ih, x_fh, x_oh
@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t; typedef struct lstm_attr_s lstm_attr_t;
// Type bundle for LSTM kernels: attr_type is lstm_attr_t and func_type takes
// a mutable lstm_t* plus a const lstm_attr_t*. T is exposed as data_type only;
// it does not appear in the function signature.
template <typename T> template <typename T>
struct LSTMTuples { struct LSTMTuple {
typedef T data_type; typedef T data_type;
typedef lstm_attr_t attr_type; typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*); typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
}; };
// Type bundle for GRU kernels: attr_type is gru_attr_t and func_type takes a
// mutable gru_t* plus a const gru_attr_t*. T is exposed as data_type only; it
// does not appear in the function signature.
template <typename T> template <typename T>
struct GRUTuples { struct GRUTuple {
typedef T data_type; typedef T data_type;
typedef gru_attr_t attr_type; typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
// Per-kernel tuples for the LSTM and GRU kernels: each invocation names a
// base tuple and a kernel (LSTMCtHt, LSTMC1H1, GRUH1, GRUHtPart1, GRUHtPart2).
// The macro is undefined right after its last use so it does not leak out of
// this header. These lines exist only in the new column of the diff.
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
template <typename T> template <typename T>
struct VBroadcastTuples { struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type; typedef T data_type;
typedef int64_t attr_type; typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t); typedef void (*func_type)(const T*, T*, int64_t, int64_t);
@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t; } seq_pool_attr_t;
template <typename T> template <typename T>
struct SeqPoolTuples { struct SeqPoolTuple {
static constexpr KernelType kernel_type = kSeqPool;
typedef T data_type; typedef T data_type;
typedef seq_pool_attr_t attr_type; typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
} emb_seq_pool_attr_t; } emb_seq_pool_attr_t;
template <typename T> template <typename T>
struct EmbSeqPoolTuples { struct EmbSeqPoolTuple {
static constexpr KernelType kernel_type = kEmbSeqPool;
typedef T data_type; typedef T data_type;
typedef emb_seq_pool_attr_t attr_type; typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const int64_t*, T*,
@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
} sgd_attr_t; } sgd_attr_t;
template <typename T> template <typename T>
struct SgdTuples { struct SgdTuple {
static constexpr KernelType kernel_type = kSgd;
typedef T data_type; typedef T data_type;
typedef sgd_attr_t attr_type; typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
} matmul_attr_t; } matmul_attr_t;
// Type bundle for the matmul kernel: attr_type is matmul_attr_t and func_type
// takes two const T* inputs, one T* and a const matmul_attr_t*. The new column
// adds kernel_type = kMatMul, so the tuple itself identifies the kernel.
template <typename T> template <typename T>
struct MatMulTuples { struct MatMulTuple {
static constexpr KernelType kernel_type = kMatMul;
typedef T data_type; typedef T data_type;
typedef matmul_attr_t attr_type; typedef matmul_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*); typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
}; };
// Type bundle for the CRF decoding kernel: attr_type is int and func_type
// takes (const int, const T*, const T*, T*, int*, int). The new column adds
// kernel_type = kCRFDecoding, so the tuple itself identifies the kernel.
template <typename T> template <typename T>
struct CRFDecodingTuples { struct CRFDecodingTuple {
static constexpr KernelType kernel_type = kCRFDecoding;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int); typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
}; };
template <typename T> template <typename T>
struct LayerNormTuples { struct LayerNormTuple {
static constexpr KernelType kernel_type = kLayerNorm;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int, typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
@ -236,7 +281,8 @@ struct LayerNormTuples {
}; };
template <typename T> template <typename T>
struct SoftmaxTuples { struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int); typedef void (*func_type)(const T*, T*, int, int);
@ -244,7 +290,8 @@ struct SoftmaxTuples {
// nChw16c = nChw16c .* NC // nChw16c = nChw16c .* NC
template <typename T> template <typename T>
struct NCHW16CMulNCTuples { struct NCHW16CMulNCTuple {
static constexpr KernelType kernel_type = kNCHW16CMulNC;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int); typedef void (*func_type)(const T*, const T*, T*, int, int);
@ -258,12 +305,12 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
template <typename KernelTuples> template <typename KernelTuple>
class KernelMore : public Kernel { class KernelMore : public Kernel {
public: public:
using T = typename KernelTuples::data_type; using T = typename KernelTuple::data_type;
using Func = typename KernelTuples::func_type; using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; } virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0; virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0; virtual const char* ImplType() const = 0;
@ -272,11 +319,11 @@ class KernelMore : public Kernel {
Func func{nullptr}; Func func{nullptr};
}; };
template <typename KernelTuples> template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuples> { class ReferKernel : public KernelMore<KernelTuple> {
public: public:
// Refer code can always be used // Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override { bool UseMe(const typename KernelTuple::attr_type& attr) const override {
return true; return true;
} }
const char* ImplType() const override { return "Refer"; } const char* ImplType() const override { return "Refer"; }

@ -26,11 +26,10 @@ namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w, void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num); float* alpha, int* track, int tag_num);
// KernelMore implementation for float CRF decoding. The constructor stores
// the CRFDecoding function declared above in func; UseMe is only declared
// here (defined elsewhere); ImplType reports "Intrinsic".
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> { class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public: public:
CRFDecodingKernel() { this->func = CRFDecoding; } CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe( bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
const typename CRFDecodingTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };

@ -27,10 +27,10 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height, const float* scale, const float* bias, int height,
const float epsilon, int right); const float epsilon, int right);
// KernelMore implementation for float layer normalization. The constructor
// stores the LayerNorm function declared above in func; UseMe is only
// declared here (defined elsewhere); ImplType reports "Intrinsic".
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> { class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public: public:
LayerNormKernel() { this->func = LayerNorm; } LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override; bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };

@ -23,6 +23,8 @@ namespace jit {
namespace more { namespace more {
namespace mix { namespace mix {
using CPUPlace = platform::CPUPlace;
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i]; y[i] = static_cast<T>(0) - y[i];
} }
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n); auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
compute(y, y, n); compute(y, y, n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]); y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
void VTanh(const T* x, T* y, int n) { void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1; const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n); auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
compute_scal(&a, x, y, n); compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n); compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n); compute_scal(&a, y, y, n);
@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
} }
void Softmax(const T* x, T* y, int n, int bs) { void Softmax(const T* x, T* y, int n, int bs) {
auto compute_hmax = auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n); auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vaddbias = auto compute_vaddbias =
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n); KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) { for (int i = 0; i < bs; ++i) {
T scalar; T scalar;
@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) { if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVRelu) { } else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVTanh) { } else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW("Not support type: %s", type);
return nullptr; return nullptr;
@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
const int d = attr->d; const int d = attr->d;
const int d2 = d * 2; const int d2 = d * 2;
const int d3 = d * 3; const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2); auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2); auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3); auto act_gate_d3 = getActFunc(attr->act_gate, d3);
@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
int d = attr->d; int d = attr->d;
int d2 = d * 2; int d2 = d * 2;
int d3 = d * 3; int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d); auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d); auto act_cell_d = getActFunc(attr->act_cell, d);
@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
int d2 = d * 2; int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d); auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d); auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
act_gate(gates, gates, d); act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d); act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d); vmul_d(gates, gates + d2, ht, d);
@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1); const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d); auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d); act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d); vmul_d(ht_1, gates + attr->d, ht, attr->d);
} }
@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
namespace mix = paddle::operators::jit::more::mix; namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \ #define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid); REGISTER_MORE_KERNEL(VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh); REGISTER_MORE_KERNEL(VTanh);
REGISTER_MORE_KERNEL(kSoftmax, Softmax); REGISTER_MORE_KERNEL(Softmax);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_MORE_KERNEL(LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_MORE_KERNEL(LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1); REGISTER_MORE_KERNEL(GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_MORE_KERNEL(GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_MORE_KERNEL(GRUHtPart2);
#undef REGISTER_MORE_KERNEL #undef REGISTER_MORE_KERNEL

@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr); void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr); void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \ #define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name; } \ name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \ const char* ImplType() const override { return "Mixed"; } \
} }
// XYN // XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); DECLARE_MORE_KERNEL(VSigmoid);
DECLARE_MORE_KERNEL(VTanh, XYNTuples); DECLARE_MORE_KERNEL(VTanh);
// XRN // XRN
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples); DECLARE_MORE_KERNEL(Softmax);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_MORE_KERNEL(LSTMCtHt);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_MORE_KERNEL(LSTMC1H1);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples); DECLARE_MORE_KERNEL(GRUH1);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart1);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart2);
#undef DECLARE_MORE_KERNEL #undef DECLARE_MORE_KERNEL

@ -250,23 +250,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
namespace mkl = paddle::operators::jit::more::mkl; namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \ #define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \ REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>) mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul); REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul); REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy); REGISTER_MKL_KERNEL(VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast); REGISTER_MKL_KERNEL(VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax); REGISTER_MKL_KERNEL(Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd); REGISTER_MKL_KERNEL(Sgd);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL

@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_MKL_KERNEL(name, tuples) \ #define DECLARE_MKL_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \ const char* ImplType() const override { return "MKL"; } \
} }
// ABCMNK // ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples); DECLARE_MKL_KERNEL(MatMul);
// XYZN // XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples); DECLARE_MKL_KERNEL(VMul);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples); DECLARE_MKL_KERNEL(VAdd);
// AXYN // AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples); DECLARE_MKL_KERNEL(VScal);
// XYN // XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VExp);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh);
DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(VSquare);
DECLARE_MKL_KERNEL(VCopy, XYNTuples); DECLARE_MKL_KERNEL(VCopy);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); // others
DECLARE_MKL_KERNEL(SeqPool);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); DECLARE_MKL_KERNEL(EmbSeqPool);
DECLARE_MKL_KERNEL(Softmax);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); DECLARE_MKL_KERNEL(Sgd);
DECLARE_MKL_KERNEL(VBroadcast);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL

@ -17,51 +17,43 @@
namespace refer = paddle::operators::jit::refer; namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \ #define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \ REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
refer::func##Kernel<double>) refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul); REGISTER_REFER_KERNEL(VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd); REGISTER_REFER_KERNEL(VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu); REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub); REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(kVScal, VScal); REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias); REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy); REGISTER_REFER_KERNEL(VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity); REGISTER_REFER_KERNEL(VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare); REGISTER_REFER_KERNEL(VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); REGISTER_REFER_KERNEL(VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh); REGISTER_REFER_KERNEL(VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_REFER_KERNEL(LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_REFER_KERNEL(LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1); REGISTER_REFER_KERNEL(GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_REFER_KERNEL(GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_REFER_KERNEL(GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding); REGISTER_REFER_KERNEL(CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(LayerNorm);
REGISTER_REFER_KERNEL(NCHW16CMulNC);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool); REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(kMatMul, MatMul); REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(kHMax, HMax); REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(kHSum, HSum); REGISTER_REFER_KERNEL(VBroadcast);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL

@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
} }
// const T* x, const T* y, T* z, int n // const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples); DECLARE_REFER_KERNEL(VMul);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples); DECLARE_REFER_KERNEL(VAdd);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples); DECLARE_REFER_KERNEL(VAddRelu);
DECLARE_REFER_KERNEL(VSub, XYZNTuples); DECLARE_REFER_KERNEL(VSub);
// const T* a, const T* x, T* y, int n // const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples); DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples); DECLARE_REFER_KERNEL(VAddBias);
// const T* x, T* y, int n // const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples); DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples); DECLARE_REFER_KERNEL(VIdentity);
DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VExp);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh);
DECLARE_REFER_KERNEL(VSquare, XYNTuples); DECLARE_REFER_KERNEL(VSquare);
DECLARE_REFER_KERNEL(VCopy, XYNTuples); DECLARE_REFER_KERNEL(VCopy);
// lstm_t*, const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_REFER_KERNEL(LSTMC1H1);
// gru_t*, const gru_attr_t* // gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples); DECLARE_REFER_KERNEL(GRUH1);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart1);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); // others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); DECLARE_REFER_KERNEL(LayerNorm);
DECLARE_REFER_KERNEL(NCHW16CMulNC);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples); DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(HMax, XRNTuples); DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(HSum, XRNTuples); DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); DECLARE_REFER_KERNEL(VBroadcast);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL

File diff suppressed because it is too large Load Diff

@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>, auto ker =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right); .At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right); static_cast<const float>(epsilon), right);

@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return; return;
} }
if (relu) { if (relu) {
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
.At(N); N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} else { } else {
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
.At(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif

@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit::seq_pool_attr_t attr( jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]), static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum); jit::SeqPoolType::kSum);
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>, auto seqpool =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr); .At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]); attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);

@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1; const int kClassDim = 1;
// 2D data. Batch x C // 2D data. Batch x C
auto compute_softmax = auto compute_softmax =
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>, jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]); .At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
} }

@ -47,9 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
int64_t rows_idx = 0; int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>, auto sgd =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced. // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
@ -82,9 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
attr.selected_rows_size = grad_rows.size(); attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>, auto sgd =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
.At(attr); attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr); sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");

Loading…
Cancel
Save