enhance the jitkernel helper and add unit tests

test=develop
revert-16045-imperative_remove_desc
tensor-tang 6 years ago
parent 14a764c930
commit 45bdd84dac

@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
BenchFunc<KernelTuple, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KernelTuple>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
for (auto f : funcs) {
infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) {

@ -81,7 +81,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override; \
bool CanBeUsed(const int& attr) const override; \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
bool VReluCreator::UseMe(const int& d) const {
bool VReluCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VSquareCreator::UseMe(const int& d) const {
bool VSquareCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VIdentityCreator::UseMe(const int& d) const {
bool VIdentityCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VExpCreator::UseMe(const int& d) const {
bool VExpCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d < 32;
}
bool VSigmoidCreator::UseMe(const int& d) const {
bool VSigmoidCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VTanhCreator::UseMe(const int& d) const {
bool VTanhCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}

@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
size_t CodeSize(const int& d) const override { \

@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override {
bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0;
}

@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
bool CanBeUsed(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \

@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \

@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const unsigned char* getCodeInternal() const override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}

@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
bool CanBeUsed(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \

@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
public:
bool UseMe(const matmul_attr_t& attr) const override {
bool CanBeUsed(const matmul_attr_t& attr) const override {
return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
}

@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public:
bool UseMe(const seq_pool_attr_t& attr) const override {
bool CanBeUsed(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {

@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public:
bool UseMe(const sgd_attr_t& attr) const override {
bool CanBeUsed(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0;
}

@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class VBroadcastCreator : public JitCodeCreator<int64_t> {
public:
bool UseMe(const int64_t& w) const override {
bool CanBeUsed(const int64_t& w) const override {
return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
}
size_t CodeSize(const int64_t& w) const override {

@ -31,7 +31,7 @@ namespace paddle {
namespace operators {
namespace jit {
// refer do not need useme, it would be the last one.
// refer do not need CanBeUsed, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;

@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual ~GenBase() = default;
virtual std::string name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
virtual const unsigned char* getCodeInternal() const = 0;
const char* ImplType() const override { return "JitCode"; }
template <typename Func>
Func getCode() {
Func getCode() const {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
virtual bool CanBeUsed(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;

@ -14,9 +14,6 @@
#pragma once
extern "C" {
#include <xxhash.h>
}
#include <iostream>
#include <string>
#include <unordered_map>
@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuple::func_type>::type
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
return codes.AllKernels().at(key).get();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
if (i && i->CanBeUsed(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
auto res = p.get();
codes.Insert(key, std::move(p));
return f;
return res;
}
}
}
@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuple::func_type>::type
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr;
}
@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <typename KernelTuple>
inline typename KernelTuple::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) {
return i->GetFunc();
return i;
}
}
return nullptr;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type Get(
template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
return p->GetFunc();
}
// Return all Kernels that can be used
template <typename KernelTuple, typename PlaceType>
std::vector<const Kernel*> GetAllCandidateKernels(
const typename KernelTuple::attr_type& attr) {
auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
// the search order shoudl be jitcode > more > refer
std::vector<const Kernel*> res;
auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitker) {
res.emplace_back(jitker);
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
// more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto& pool = KernelPool::Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
if (i && i->CanBeUsed(attr)) {
res.emplace_back(i);
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KernelTuple>();
auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
res.emplace_back(ref);
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<std::pair<std::string, typename KernelTuple::func_type>>
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
std::vector<std::pair<std::string, Func>> res;
for (auto k : kers) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
std::vector<typename KernelTuple::func_type> res;
for (auto& i : funcs) {
res.emplace_back(i.second);
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return funcs[0];
}
template <typename KernelTuple, typename PlaceType>
@ -134,17 +187,13 @@ class KernelFuncs {
// the exposed interface to use
typename KernelTuple::func_type At(
const typename KernelTuple::attr_type& attr) {
// XXH64: 13.8 GB/s
// TODO(TJ): change me, maybe not all attr change need one key, should be
// attrkey
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
// Maybe here is not good enough, not all kernels should have jitcode
int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
if (Has(key)) {
return funcs_.at(key);
}
// If do not have this attr in cache,
// then could run some runtime benchmark of this attr and save the best one.
// Here just get the offline benchmarked best one.
auto func = Get<KernelTuple, PlaceType>(attr);
// If do not have this attr in cache then get the default best
auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
Insert(key, func);
return func;
}
@ -156,7 +205,6 @@ class KernelFuncs {
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func);
}

@ -302,6 +302,7 @@ class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
virtual const char* ImplType() const = 0;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
@ -312,8 +313,8 @@ class KernelMore : public Kernel {
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
// specify this kernel can be used, means it should not fail if use it.
virtual bool CanBeUsed(const Attr& attr) const = 0;
protected:
Func func{nullptr};
@ -323,7 +324,7 @@ template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuple> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuple::attr_type& attr) const override {
bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }

@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
// XXH64: 13.8 GB/s
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);

@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
bool CRFDecodingKernel::CanBeUsed(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else

@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
bool CanBeUsed(
const typename CRFDecodingTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};

@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
bool LayerNormKernel::UseMe(const int& d) const {
bool LayerNormKernel::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}

@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
bool CanBeUsed(
const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};

@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
// TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; }
bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; }
bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
bool SoftmaxKernel::UseMe(const int& d) const { return true; }
bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
} // namespace mix
} // namespace more

@ -34,12 +34,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
#define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN

@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
bool VMulKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
bool VAddKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(const int& d) const {
bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(const int& d) const {
bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
bool VSquareKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VCopyKernel<float>::UseMe(const int& d) const {
bool VCopyKernel<float>::CanBeUsed(const int& d) const {
return d > 15;
}
template <>
bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
return d > 127;
}
template <>
bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
return true;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(const int& d) const {
bool VTanhKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<double>::CanBeUsed(
const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
}
template <>
bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
return true;
}
template <>
bool SoftmaxKernel<float>::UseMe(const int& d) const {
bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
// tuned on avx2
return platform::MayIUse(platform::avx) && d < 60;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
return true; \
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::CanBeUsed(const int& d) const { \
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(VMul);

@ -175,13 +175,13 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
#define DECLARE_MKL_KERNEL(name) \
template <typename T> \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
#define DECLARE_MKL_KERNEL(name) \
template <typename T> \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save