parent 95fb31285c
commit d53c4756ad
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,90 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"

DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {

constexpr Xbyak::Operand::Code g_abi_regs[] = {
    Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
    Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};

constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);

void JitCode::preCode() {
  for (int i = 0; i < num_g_abi_regs; ++i) {
    push(Xbyak::Reg64(g_abi_regs[i]));
  }
  if (platform::MayIUse(platform::avx512f)) {
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
  }
}

void JitCode::postCode() {
  for (int i = 0; i < num_g_abi_regs; ++i) {
    pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
  }
  ret();
}

void JitCode::dumpCode(const Xbyak::uint8 *code) const {
  if (code) {
    static int counter = 0;
    std::ostringstream filename;
    filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
    counter++;
    std::ofstream fout(filename.str(), std::ios::out);
    if (fout.is_open()) {
      fout.write(reinterpret_cast<const char *>(code), getSize());
      fout.close();
    }
  }
}

Xbyak::Address JitCode::EVEX_compress_addr(Xbyak::Reg64 base, int offt,
                                           bool bcast) {
  int scale = 0;
  if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
    offt = offt - 2 * EVEX_max_8b_offt;
    scale = 1;
  } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
    offt = offt - 4 * EVEX_max_8b_offt;
    scale = 2;
  }
  auto re = Xbyak::RegExp() + base + offt;
  if (scale) {
    re = re + reg_EVEX_max_8b_offt * scale;
  }
  if (bcast) {
    return zword_b[re];
  } else {
    return zword[re];
  }
}

}  // namespace gen
}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
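Editor's note: a worked example of the offset folding in EVEX_compress_addr above, under the assumption that preCode() has loaded reg_EVEX_max_8b_offt with 2 * EVEX_max_8b_offt = 0x400 (constants taken from the header below; the concrete offsets are illustrative only):

// Illustration only, assuming EVEX_max_8b_offt == 0x200 and
// reg_EVEX_max_8b_offt holding 0x400 at runtime.
//
//   offt = 0x300  ->  0x200 <= 0x300 < 0x600, so offt -= 0x400, scale = 1
//                     address = base + (0x300 - 0x400) + 0x400 * 1
//                             = base + 0x300   (same address, but the
//                               displacement that remains in the encoding
//                               is small, which presumably lets the EVEX
//                               disp8*N compressed form be used)
//
//   offt = 0x700  ->  0x600 <= 0x700 < 0xA00, so offt -= 0x800, scale = 2
//                     address = base + (0x700 - 0x800) + 0x400 * 2
//                             = base + 0x700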
@@ -1,80 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"

#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"

DECLARE_bool(dump_jitcode);

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {

#define DECLARE_JIT_CODE(codename) \
  const char *name() const override { return #codename; }

// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
    abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
    abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);

class JitCode : public Xbyak::CodeGenerator {
 public:
  explicit JitCode(size_t code_size = 256 * 1024, void *code_ptr = nullptr)
      : Xbyak::CodeGenerator(code_size, code_ptr) {}

  virtual ~JitCode() {}
  virtual const char *name() const = 0;
  virtual void generate() = 0;

  template <typename FUNC>
  const FUNC getCode() {
    this->generate();
    const Xbyak::uint8 *code = CodeGenerator::getCode();
    if (FLAGS_dump_jitcode) {
      this->dumpCode(code);
    }
    return reinterpret_cast<const FUNC>(code);
  }
  DISABLE_COPY_AND_ASSIGN(JitCode);

 protected:
  Xbyak::Reg64 param1{abi_param1};
  const int EVEX_max_8b_offt = 0x200;
  const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;

  void preCode();
  void postCode();
  void dumpCode(const Xbyak::uint8 *code) const;
  void L(const char *label) { Xbyak::CodeGenerator::L(label); }
  void L(const Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }
  // Enhanced vector extension
  Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
                                    bool bcast = false);
};

}  // namespace gen
}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
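Editor's note: a minimal sketch of how a concrete generator builds on this base class. The class name VFooJitCode and its loop body are hypothetical, not part of the removed code; the pattern (implement generate(), bracket the emitted body with preCode()/postCode(), fetch a typed function pointer via getCode<FUNC>()) follows the interface shown above.

// Hypothetical illustration only (assumes the jit_gen.h declarations above).
class VFooJitCode : public JitCode {
 public:
  DECLARE_JIT_CODE(VFooJitCode);
  explicit VFooJitCode(size_t code_size = 256 * 1024) : JitCode(code_size) {}
  void generate() override {
    preCode();   // push callee-saved ABI registers (and set up EVEX offset reg)
    // ... emit Xbyak instructions reading abi_param1 / abi_param2 here ...
    postCode();  // pop the registers and ret()
  }
};

// Usage sketch: JIT the code once, then call it like a plain C function.
//   VFooJitCode jit;
//   auto fn = jit.getCode<void (*)(const float *, float *, int)>();
//   fn(x, y, n);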
@@ -1,39 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

KernelPool& KernelPool::Instance() {
  static thread_local KernelPool g_jit_kernels;
  return g_jit_kernels;
}

std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
  if (kers_.find(key) == kers_.end()) {
    return nullptr;
  }
  return kers_.at(key);
}

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -1,157 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <functional>
#include <memory>  // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"

// Note: Only supported on CPU yet.
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

// TODO(TJ): remove me
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;

class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;
  // TODO(TJ): below members should be deprecated.
  int num_{0};
  int end_{0};
  int rest_{0};
  DISABLE_COPY_AND_ASSIGN(Kernel);
};

class KernelPool {
 public:
  static KernelPool &Instance();

  template <typename Ker, typename... ARGS>
  std::shared_ptr<const Ker> Get(ARGS... args);

  std::shared_ptr<const Kernel> Get(const std::string &key) const;

 private:
  KernelPool() = default;
  std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;

  DISABLE_COPY_AND_ASSIGN(KernelPool);
};

template <typename T>
class VMulKernel : public Kernel {
 public:
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VAddKernel : public Kernel {
 public:
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VAddReluKernel : public Kernel {
 public:
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VScalKernel : public Kernel {
 public:
  // y = a.*x
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VAddBiasKernel : public Kernel {
 public:
  // y = a.+x
  void (*Compute)(const T *, const T *, T *, int);
};

#ifdef PADDLE_WITH_MKLDNN
template <typename T>
class EltwiseMulnChw16cNCKernel : public Kernel {
 public:
  // nChw16c = nChw16c .* NC
  void (*Compute)(const float *, const float *, float *, int, int);
};
#endif

template <typename T>
class VActKernel : public Kernel {
 public:
  void (*Compute)(const T *, T *, int);
};

template <typename T>
class VReluKernel : public VActKernel<T> {};

template <typename T>
class VIdentityKernel : public VActKernel<T> {};

template <typename T>
class VExpKernel : public VActKernel<T> {};

template <typename T>
class VSigmoidKernel : public VActKernel<T> {};

template <typename T>
class VTanhKernel : public VActKernel<T> {};

template <typename T>
class LSTMKernel : public Kernel {
 public:
  // compute c1 and h1 without c0 or h0
  void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
  void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
};

template <typename T>
class GRUKernel : public Kernel {
 public:
  // compute h1 without h0
  void (*ComputeH1)(gru_t *, const gru_attr_t *);
  void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
  void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
};

template <typename T>
class CRFDecodeKernel : public Kernel {
 public:
  virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha,
                       int *track) const = 0;
};

template <typename T>
class LayerNormKernel : public Kernel {
 public:
  virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale,
                       const T *bias, int height,
                       const float epsilon) const = 0;
};

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
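Editor's note: a hedged caller-side sketch of how a kernel is usually obtained from this pool. Namespace qualifiers are omitted for brevity, and it assumes a VMulKernel<float> implementation is registered elsewhere (e.g. in one of the suppressed blas files); the function name VMulExample and the data are placeholders. The templated Get caches one instance per string key, so repeated lookups are cheap.

// Hypothetical usage sketch (not from the removed files).
#include <vector>

void VMulExample() {
  const int d = 256;  // vector length
  std::vector<float> x(d, 1.0f), y(d, 2.0f), out(d);
  const auto ker = KernelPool::Instance().Get<VMulKernel<float>>(d);
  ker->Compute(x.data(), y.data(), out.data(), d);  // out[i] = x[i] * y[i]
}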
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,195 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"

#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif

#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

/* VExp JitKernel */
template <typename T>
class VExpKernelImpl : public VExpKernel<T> {
 public:
  JITKERNEL_DECLARE_STATIC_FUNC;
  explicit VExpKernelImpl(int d) : VExpKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8;
      jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
                                          sz > 4096 ? sz : 4096));
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif
#ifdef PADDLE_WITH_MKLML
    if (useMKL(d)) {
      this->Compute = VExpMKL<T>;
      return;
    }
#endif
    this->Compute = refer::VExp<T>;
  }

#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool VExpKernelImpl<float>::useJIT(int d) {
  return gen::VActJitCode::init(d, gen::operand_type::exp);
}
#endif

#ifdef PADDLE_WITH_MKLML
template <>
bool VExpKernelImpl<float>::useMKL(int d) {
  return d > 512;
}

template <>
bool VExpKernelImpl<double>::useMKL(int d) {
  return true;
}

#endif

/* VSigmoid JitKernel */
template <typename T>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
 public:
  JITKERNEL_DECLARE_STATIC_FUNC;
  explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8;
      jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
                                          sz > 4096 ? sz : 4096));
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif

#ifdef PADDLE_WITH_MKLML
    // strictly speaking, the MKL implementation is better than the reference
    if (useMKL(d)) {
      this->Compute = VSigmoidMKL<T>;
      return;
    }
#endif
    this->Compute = refer::VSigmoid<T>;
  }

#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool VSigmoidKernelImpl<float>::useJIT(int d) {
  return gen::VActJitCode::init(d, gen::operand_type::sigmoid);
}
#endif

#ifdef PADDLE_WITH_MKLML
template <>
bool VSigmoidKernelImpl<float>::useMKL(int d) {
  return d > 512;
}

template <>
bool VSigmoidKernelImpl<double>::useMKL(int d) {
  return true;
}
#endif

/* VTanh JitKernel */
template <typename T>
class VTanhKernelImpl : public VTanhKernel<T> {
 public:
  JITKERNEL_DECLARE_STATIC_FUNC;
  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8;
      jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
                                          sz > 4096 ? sz : 4096));
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif

#ifdef PADDLE_WITH_MKLML
    // strictly speaking, the MKL implementation is better than the reference
    if (useMKL(d)) {
      this->Compute = VTanhMKL<T>;
      return;
    }
#endif
    this->Compute = refer::VTanh<T>;
  }

#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool VTanhKernelImpl<float>::useJIT(int d) {
  return gen::VActJitCode::init(d, gen::operand_type::tanh);
}
#endif

#ifdef PADDLE_WITH_MKLML
template <>
bool VTanhKernelImpl<float>::useMKL(int d) {
  return d > 512;
}

template <>
bool VTanhKernelImpl<double>::useMKL(int d) {
  return true;
}
#endif

REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel);

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -1,34 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <type_traits>

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
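Editor's note: a minimal sketch of how the SIGMOID_THRESHOLD_MIN/MAX constants above are typically applied, modelled on the reference (non-JIT) sigmoid kernel; the function name VSigmoidRef is illustrative and not part of the removed files. Inputs are clamped before exp() so the exponential cannot overflow.

// Sketch only, assuming the macros defined in the header above are visible.
#include <cmath>

template <typename T>
void VSigmoidRef(const T* x, T* y, int n) {
  const T min = SIGMOID_THRESHOLD_MIN;  // clamp lower bound
  const T max = SIGMOID_THRESHOLD_MAX;  // clamp upper bound
  for (int i = 0; i < n; ++i) {
    T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
    y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
  }
}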
@@ -1,239 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <math.h>
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

/* Layer Norm JitKernel */
template <typename T, platform::cpu_isa_t isa, jit_block>
class LayerNormKernelImpl : public LayerNormKernel<T> {
 public:
  explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
    this->num_ = right;
  }

  void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
               int height, const float epsilon) const override {
    // get mean
    for (int i = 0; i < height; i++) {
      T sum = 0.0;
      int offset = i * this->num_;
      for (int j = 0; j < this->num_; j++) {
        sum += x[offset + j];
      }
      mean[i] = sum / this->num_;
    }

    // get variance
    for (int i = 0; i < height; i++) {
      T sum = 0.0;
      int offset = i * this->num_;
      for (int j = 0; j < this->num_; j++) {
        sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
      }
      var[i] = sum / this->num_;
    }

    for (int i = 0; i < height; i++) {
      int offset = i * this->num_;
      T sqrt_var = sqrt(var[i] + (T)epsilon);
      for (int j = 0; j < this->num_; j++) {
        out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
      }
    }
    if (scale) {
      for (int i = 0; i < height; i++) {
        int offset = i * this->num_;
        for (int j = 0; j < this->num_; j++) {
          out[offset + j] *= scale[j];
        }
      }
    }

    if (bias) {
      for (int i = 0; i < height; i++) {
        int offset = i * this->num_;
        for (int j = 0; j < this->num_; j++) {
          out[offset + j] += bias[j];
        }
      }
    }
  }
};

#define INTRIAVX_FLOAT(isa, jit_block) \
  template <> \
  LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
      : LayerNormKernel<float>() { \
    this->num_ = right; \
    this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
    this->end_ = this->num_ - this->rest_; \
  } \
  template <> \
  void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
      float* x, float* out, float* mean, float* var, const float* scale, \
      const float* bias, int height, const float epsilon) const { \
    __m256 sum; \
    __m256 mean_vec, var_vec; \
    __m128 hi, lo; \
    __m256 tmp; \
    size_t offset; \
    size_t j; \
    size_t block = YMM_FLOAT_BLOCK; \
    __m256 reverse_num_vec = \
        _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
    __m256 epsilon_vec = _mm256_set1_ps(epsilon); \
    int rest_mask = \
        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
        0x0ff; \
    __m256i mask_vec = _mm256_set_epi32( \
        rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
        rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
        rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \
        rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \
    \
    for (int i = 0; i < height; ++i) { \
      offset = i * this->num_; \
      \
      /* get mean */ \
      sum = _mm256_setzero_ps(); \
      for (j = offset; j < end_ + offset; j += block) { \
        sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \
      } \
      if (rest_ != 0) { \
        j = offset + this->num_ - block; \
        tmp = _mm256_loadu_ps((const float*)x + j); \
        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
        sum = _mm256_add_ps(sum, tmp); \
      } \
      hi = _mm256_extractf128_ps(sum, 1); \
      lo = _mm256_extractf128_ps(sum, 0); \
      sum = _mm256_add_ps( \
          sum, _mm256_insertf128_ps( \
                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
      sum = _mm256_hadd_ps(sum, sum); \
      sum = _mm256_hadd_ps(sum, sum); \
      mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \
      mean[i] = *reinterpret_cast<float*>(&mean_vec); \
      \
      /* get variance */ \
      sum = _mm256_setzero_ps(); \
      for (j = offset; j < end_ + offset; j += block) { \
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
        tmp = _mm256_mul_ps(tmp, tmp); \
        sum = _mm256_add_ps(sum, tmp); \
      } \
      if (rest_ != 0) { \
        j = offset + this->num_ - block; \
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
        tmp = _mm256_mul_ps(tmp, tmp); \
        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
        sum = _mm256_add_ps(sum, tmp); \
      } \
      hi = _mm256_extractf128_ps(sum, 1); \
      lo = _mm256_extractf128_ps(sum, 0); \
      sum = _mm256_add_ps( \
          sum, _mm256_insertf128_ps( \
                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
      sum = _mm256_hadd_ps(sum, sum); \
      sum = _mm256_hadd_ps(sum, sum); \
      var_vec = _mm256_mul_ps(sum, reverse_num_vec); \
      var[i] = *reinterpret_cast<float*>(&var_vec); \
      \
      /* get x_norm and calculate output */ \
      for (j = offset; j < end_ + offset; j += block) { \
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
        tmp = _mm256_div_ps( \
            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
      } \
      if (rest_ != 0) { \
        j = offset + num_ - block; \
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
        tmp = _mm256_div_ps( \
            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
      } \
      \
      if (scale) { \
        if (rest_ != 0) { \
          j = offset + this->num_ - block; \
          tmp = _mm256_loadu_ps((const float*)out + j); \
        } \
        for (j = offset; j < end_ + offset; j += block) { \
          _mm256_storeu_ps( \
              reinterpret_cast<float*>(out) + j, \
              _mm256_mul_ps( \
                  _mm256_loadu_ps((const float*)out + j), \
                  _mm256_loadu_ps((const float*)scale + j - offset))); \
        } \
        if (rest_ != 0) { \
          j = offset + this->num_ - block; \
          _mm256_storeu_ps( \
              reinterpret_cast<float*>(out) + j, \
              _mm256_mul_ps( \
                  tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \
        } \
      } \
      \
      if (bias) { \
        if (rest_ != 0) { \
          j = offset + this->num_ - block; \
          tmp = _mm256_loadu_ps((const float*)out + j); \
        } \
        for (j = offset; j < end_ + offset; j += block) { \
          _mm256_storeu_ps( \
              reinterpret_cast<float*>(out) + j, \
              _mm256_add_ps( \
                  _mm256_loadu_ps((const float*)out + j), \
                  _mm256_loadu_ps((const float*)bias + j - offset))); \
        } \
        if (rest_ != 0) { \
          j = offset + this->num_ - block; \
          _mm256_storeu_ps( \
              reinterpret_cast<float*>(out) + j, \
              _mm256_add_ps( \
                  tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \
        } \
      } \
    } \
  }

#ifdef __AVX__
INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16);
INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
INTRIAVX_FLOAT(platform::avx512f, kGT16);
#endif

#undef INTRIAVX_FLOAT

REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel);

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -1,179 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

#define JITKERNEL_DECLARE_STATIC_FUNC \
  static inline std::string name(int d) { \
    PADDLE_THROW("DType should be either float or double"); \
  } \
  static inline bool useJIT(int d) { return false; } \
  static inline bool useMKL(int d) { return false; }

#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
  template <> \
  std::string ker_class##Impl<float>::name(int d) { \
    std::string key(#ker_key "f"); \
    if (useJIT(d)) { \
      /* only jit code needs to record d */ \
      return key + "jit" + std::to_string(d); \
    } else if (useMKL(d)) { \
      return key + "mkl"; \
    } else { \
      return key + "any"; \
    } \
  } \
  template <> \
  std::string ker_class##Impl<double>::name(int d) { \
    std::string key(#ker_key "d"); \
    /* jit code does not support double yet */ \
    if (useMKL(d)) { \
      return key + "mkl"; \
    } else { \
      return key + "any"; \
    } \
  }

#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
  template <> \
  std::shared_ptr<const ker_class<ker_dtype>> \
  KernelPool::Get<ker_class<ker_dtype>, int>(int d)

#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
  std::string key = ker_class##Impl<ker_dtype>::name(d)

#define JITKERNEL_IMPL(ker_class, ker_dtype) \
  p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
      std::make_shared<ker_class##Impl<ker_dtype>>(d))

#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
                                      macro_find_key, macro_impl) \
  marco_declare(ker_class, ker_dtype) { \
    macro_find_key(ker_class, ker_dtype); \
    if (kers_.find(key) == kers_.end()) { \
      std::shared_ptr<ker_class<ker_dtype>> p; \
      macro_impl(ker_class, ker_dtype); \
      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
      return p; \
    } \
    return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
        kers_.at(key)); \
  }

#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
                                marco_declare, macro_find_key, macro_impl) \
  marco_define_name(ker_key, ker_class); \
  REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
                                macro_find_key, macro_impl); \
  REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
                                macro_find_key, macro_impl)

#define REGISTER_JITKERNEL(ker_key, ker_class) \
  REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
                          JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
                          JITKERNEL_IMPL)

// TODO(TJ): the defines below are deprecated and will be removed soon
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
  if (d < YMM_FLOAT_BLOCK) { \
    macro_(ker, dtype, isa, kLT8); \
  } else if (d == YMM_FLOAT_BLOCK) { \
    macro_(ker, dtype, isa, kEQ8); \
  } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
    macro_(ker, dtype, isa, kGT8LT16); \
  } else if (d == ZMM_FLOAT_BLOCK) { \
    macro_(ker, dtype, isa, kEQ16); \
  } else { \
    macro_(ker, dtype, isa, kGT16); \
  }

#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
  if (platform::MayIUse(platform::avx512f)) { \
    SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
  } else if (platform::MayIUse(platform::avx2)) { \
    SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
  } else if (platform::MayIUse(platform::avx)) { \
    SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
  } else { \
    SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
  }

#define JITKERNEL_KEY(ker_key, dtype_key) \
  #ker_key #dtype_key + std::to_string(d)

#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \
  p = std::dynamic_pointer_cast<ker<dtype>>( \
      std::make_shared<ker##Impl<dtype, isa, k>>(d))

#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
                                        dtype_key, marco_declare, macro_key, \
                                        macro_impl) \
  marco_declare(ker_class, ker_dtype) { \
    std::string key = macro_key(ker_key, dtype_key); \
    if (kers_.find(key) == kers_.end()) { \
      std::shared_ptr<ker_class<ker_dtype>> p; \
      SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
      return p; \
    } \
    return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
        kers_.at(key)); \
  }

#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
                                  JITKERNEL_DECLARE, JITKERNEL_KEY, \
                                  JITKERNEL_NEW_IMPL_DEPRECATED); \
  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
                                  JITKERNEL_DECLARE, JITKERNEL_KEY, \
                                  JITKERNEL_NEW_IMPL_DEPRECATED)

#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
                                           macro_key, macro_impl) \
  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
                                  macro_key, macro_impl); \
  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
                                  marco_declare, macro_key, macro_impl)

#define FOR_EACH_ISA(macro_, block) \
  macro_(platform::avx512f, block); \
  macro_(platform::avx2, block); \
  macro_(platform::avx, block); \
  macro_(platform::isa_any, block)

#define FOR_EACH_BLOCK(macro_, isa) \
  macro_(isa, kLT8); \
  macro_(isa, kEQ8); \
  macro_(isa, kGT8LT16); \
  macro_(isa, kEQ16); \
  macro_(isa, kGT16)

#define FOR_EACH_ISA_BLOCK(macro_) \
  FOR_EACH_BLOCK(macro_, platform::avx512f); \
  FOR_EACH_BLOCK(macro_, platform::avx2); \
  FOR_EACH_BLOCK(macro_, platform::avx); \
  FOR_EACH_BLOCK(macro_, platform::isa_any)

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
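Editor's note: a hedged, hand-abbreviated sketch of what REGISTER_JITKERNEL(vexp, VExpKernel) from the exp kernel file above roughly produces for the float instantiation, to show how the macros fit together. Namespace qualifiers are omitted and this is not a literal preprocessor expansion.

// JITKERNEL_DEFINE_NAME provides VExpKernelImpl<float>::name(d), which returns
// "vexpf" followed by "jit" + d, "mkl", or "any" depending on useJIT/useMKL.
// JITKERNEL_DECLARE + JITKERNEL_FIND_KEY + JITKERNEL_IMPL then specialize the
// pool getter, roughly as follows:
template <>
std::shared_ptr<const VExpKernel<float>>
KernelPool::Get<VExpKernel<float>, int>(int d) {
  std::string key = VExpKernelImpl<float>::name(d);        // JITKERNEL_FIND_KEY
  if (kers_.find(key) == kers_.end()) {
    std::shared_ptr<VExpKernel<float>> p;
    p = std::dynamic_pointer_cast<VExpKernel<float>>(       // JITKERNEL_IMPL
        std::make_shared<VExpKernelImpl<float>>(d));
    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
    return p;
  }
  return std::dynamic_pointer_cast<const VExpKernel<float>>(kers_.at(key));
}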
Some files were not shown because too many files have changed in this diff