Merge pull request #14753 from tensor-tang/refine/namespace

remove jit namespace
revert-14398-imperative
tensor-tang 6 years ago committed by GitHub
commit dbe451976b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM.
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
if (bias) {
math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
math::vec_relu<T, platform::jit::avx>(n, y, y);
math::vec_add_bias<T, platform::avx>(n, *bias, x, y);
math::vec_relu<T, platform::avx>(n, y, y);
} else {
math::vec_relu<T, platform::jit::avx>(n, x, y);
math::vec_relu<T, platform::avx>(n, x, y);
}
}
@ -245,8 +245,8 @@ inline void vec_softmax(const int n, const T* x, T* y) {
for (int i = 1; i < n; ++i) {
scalar = scalar < x[i] ? x[i] : scalar;
}
math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y); // sub
math::vec_exp<T>(n, y, y); // exp
math::vec_add_bias<T, platform::avx>(n, -scalar, x, y); // sub
math::vec_exp<T>(n, y, y); // exp
// sum
scalar = T(0);
for (int i = 0; i < n; ++i) {
@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
if (platform::MayIUse(platform::avx)) {
math::VecActivations<T, platform::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
math::VecActivations<T, platform::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);

@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
if (platform::MayIUse(platform::avx)) { \
math::VecActivations<T, platform::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
math::VecActivations<T, platform::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \

@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
std::function<void(const int, const T*, T*)> fc_act;
auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
if (platform::MayIUse(platform::avx)) {
math::VecActivations<T, platform::avx> act_functor;
fc_act = act_functor(fc_act_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
math::VecActivations<T, platform::isa_any> act_functor;
fc_act = act_functor(fc_act_str);
}

File diff suppressed because it is too large Load Diff

@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
}
TEST(CpuVecTest, sigmoid) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>,
TestAndBench<float>(sz, vec_sigmoid<float, platform::avx>,
ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, platform::avx2>,
ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, platform::avx512f>,
ref_sigmoid<float>);
}
TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
}
TEST(CpuVecTest, tanh) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, platform::avx512f>,
ref_tanh<float>);
}
TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
}
TEST(CpuVecTest, relu) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, platform::avx512f>,
ref_relu<float>);
}
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
}
@ -162,38 +166,40 @@ void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
}
TEST(CpuVecTest, inplace_sigmoid) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>,
TestInplace<float>(sz, vec_sigmoid<float, platform::avx>,
ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, platform::avx2>,
ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, platform::avx512f>,
ref_sigmoid<float>);
}
TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
}
TEST(CpuVecTest, inplace_tanh) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, platform::avx512f>, ref_tanh<float>);
}
TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
}
TEST(CpuVecTest, inplace_relu) {
namespace jit = paddle::platform::jit;
namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, platform::avx512f>, ref_relu<float>);
}
TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
}

@ -22,7 +22,7 @@ namespace math {
namespace jitkernel {
namespace gen {
using namespace platform::jit; // NOLINT
using namespace platform; // NOLINT
bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency

@ -179,7 +179,7 @@ class VActJitCode : public JitCode {
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
using namespace platform::jit; // NOLINT
using namespace platform; // NOLINT
// check all idx can not equal
JMM jmm_src = JMM(src_idx);
JMM jmm_fx = JMM(fx_idx);

@ -36,7 +36,7 @@ void JitCode::preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::jit::MayIUse(platform::jit::avx512f)) {
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}

@ -21,8 +21,6 @@ namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
KernelPool& KernelPool::Instance() {
static thread_local KernelPool g_jit_kernels;
return g_jit_kernels;

@ -30,7 +30,6 @@ namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
#ifdef PADDLE_WITH_MKLML
template <typename T>
@ -125,7 +124,7 @@ bool VMulKernelImpl<float>::useJIT(int d) {
#ifdef PADDLE_WITH_MKLML
template <>
bool VMulKernelImpl<float>::useMKL(int d) {
return jit::MayIUse(jit::avx512f) && d > 512;
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>

@ -25,10 +25,8 @@ namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
/* CRF Decode JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T, platform::cpu_isa_t isa, jit_block>
class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
public:
explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() {
@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
#define INTRIAVX_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, jit::avx, block>::CRFDecodeKernelImpl( \
CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
void CRFDecodeKernelImpl<float, platform::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
#define INTRIAVX512_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, jit::avx512f, block>::CRFDecodeKernelImpl( \
CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(ZMM_FLOAT_BLOCK) \
@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT(jit::avx2, kEQ8);
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ16);
INTRIAVX2_FLOAT(jit::avx2, kGT16);
INTRIAVX2_FLOAT(platform::avx2, kEQ8);
INTRIAVX2_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX2_FLOAT(platform::avx2, kEQ16);
INTRIAVX2_FLOAT(platform::avx2, kGT16);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
INTRIAVX2_FLOAT(platform::avx512f, kEQ8);
INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16);
#endif

@ -29,7 +29,6 @@ namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup

@ -22,10 +22,8 @@ namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
/* Layer Norm JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T, platform::cpu_isa_t isa, jit_block>
class LayerNormKernelImpl : public LayerNormKernel<T> {
public:
explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, jit::avx, block>::Compute( \
void LayerNormKernelImpl<float, platform::avx, block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
}
#ifdef __AVX__
INTRIAVX_FLOAT(jit::avx, kEQ8);
INTRIAVX_FLOAT(jit::avx, kGT8LT16);
INTRIAVX_FLOAT(jit::avx, kEQ16);
INTRIAVX_FLOAT(jit::avx, kGT16);
INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16);
#endif
#ifdef __AVX2__
INTRIAVX_FLOAT(jit::avx2, kEQ8);
INTRIAVX_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX_FLOAT(jit::avx2, kEQ16);
INTRIAVX_FLOAT(jit::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16);
#endif
#undef INTRIAVX_FLOAT

@ -92,7 +92,6 @@ namespace jitkernel {
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
namespace jit = platform::jit;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < YMM_FLOAT_BLOCK) { \
@ -107,15 +106,15 @@ namespace jit = platform::jit;
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (jit::MayIUse(jit::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (platform::MayIUse(platform::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
} else if (platform::MayIUse(platform::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
} else if (platform::MayIUse(platform::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
}
#define JITKERNEL_KEY(ker_key, dtype_key) \
@ -156,10 +155,10 @@ namespace jit = platform::jit;
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
macro_(jit::avx2, block); \
macro_(jit::avx, block); \
macro_(jit::isa_any, block)
macro_(platform::avx512f, block); \
macro_(platform::avx2, block); \
macro_(platform::avx, block); \
macro_(platform::isa_any, block)
#define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8); \
@ -168,11 +167,11 @@ namespace jit = platform::jit;
macro_(isa, kEQ16); \
macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, jit::avx512f); \
FOR_EACH_BLOCK(macro_, jit::avx2); \
FOR_EACH_BLOCK(macro_, jit::avx); \
FOR_EACH_BLOCK(macro_, jit::isa_any)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, platform::avx512f); \
FOR_EACH_BLOCK(macro_, platform::avx2); \
FOR_EACH_BLOCK(macro_, platform::avx); \
FOR_EACH_BLOCK(macro_, platform::isa_any)
} // namespace jitkernel
} // namespace math

@ -705,7 +705,7 @@ TEST(JitKernel, pool) {
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
paddle::platform::MayIUse(paddle::platform::avx);
const auto& plstm1 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);

@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() {
return CUDAPinnedMaxAllocSize() / 256;
}
namespace jit {
#ifdef PADDLE_WITH_XBYAK
static Xbyak::util::Cpu cpu;
bool MayIUse(const cpu_isa_t cpu_isa) {
@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
}
#endif
} // namespace jit
} // namespace platform
} // namespace paddle

@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize();
namespace jit {
typedef enum {
isa_any,
sse42,
@ -55,7 +54,5 @@ typedef enum {
// May I use some instruction
bool MayIUse(const cpu_isa_t cpu_isa);
} // namespace jit
} // namespace platform
} // namespace paddle

@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif
#if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__)
if (platform::jit::MayIUse(platform::jit::avx)) {
if (platform::MayIUse(platform::avx)) {
#ifndef __AVX__
LOG(WARNING) << "AVX is available, Please re-compile on local machine";
#endif
@ -131,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
" version or compile from source code."
#ifdef __AVX512F__
if (!platform::jit::MayIUse(platform::jit::avx512f)) {
if (platform::jit::MayIUse(platform::jit::avx2)) {
if (!platform::MayIUse(platform::avx512f)) {
if (platform::MayIUse(platform::avx2)) {
AVX_GUIDE(AVX512, AVX2);
} else if (platform::jit::MayIUse(platform::jit::avx)) {
} else if (platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX512, AVX);
} else {
AVX_GUIDE(AVX512, NonAVX);
@ -143,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif
#ifdef __AVX2__
if (!platform::jit::MayIUse(platform::jit::avx2)) {
if (platform::jit::MayIUse(platform::jit::avx)) {
if (!platform::MayIUse(platform::avx2)) {
if (platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX2, AVX);
} else {
AVX_GUIDE(AVX2, NonAVX);
@ -153,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif
#ifdef __AVX__
if (!platform::jit::MayIUse(platform::jit::avx)) {
if (!platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX, NonAVX);
}
#endif

Loading…
Cancel
Save