|
|
|
@ -67,7 +67,7 @@ class VActFunc : public JitCode {
|
|
|
|
|
virtual void genCode() = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// compute relu with ymm, xmm
|
|
|
|
|
// compute RELU with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
|
|
|
|
|
JMM zero = JMM(zero_idx);
|
|
|
|
@ -75,7 +75,7 @@ class VActFunc : public JitCode {
|
|
|
|
|
vmaxps(dst, src, zero);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute exp with ymm, xmm
|
|
|
|
|
// compute EXP with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
|
|
|
|
|
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
|
|
|
|
@ -159,7 +159,7 @@ class VActFunc : public JitCode {
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute sigmoid with ymm, xmm
|
|
|
|
|
// compute SIGMOID with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
|
|
|
|
|
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
|
|
|
|
@ -184,7 +184,7 @@ class VActFunc : public JitCode {
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute tanh with ymm, xmm
|
|
|
|
|
// compute TANH with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
|
|
|
|
|
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
|
|
|
|
@ -211,7 +211,7 @@ class VActFunc : public JitCode {
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute identity with ymm, xmm
|
|
|
|
|
// compute IDENTITY with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT
|
|
|
|
|
JMM zero = JMM(zero_idx);
|
|
|
|
@ -225,19 +225,19 @@ class VActFunc : public JitCode {
|
|
|
|
|
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
|
|
|
|
|
// use 11~15
|
|
|
|
|
switch (type) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
case operand_type::RELU:
|
|
|
|
|
relu_jmm<JMM>(dst, src, 15);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
case operand_type::EXP:
|
|
|
|
|
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
case operand_type::SIGMOID:
|
|
|
|
|
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
case operand_type::TANH:
|
|
|
|
|
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
case operand_type::IDENTITY:
|
|
|
|
|
identity_jmm<JMM>(dst, src, 15);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
@ -252,9 +252,9 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
explicit VActJitCode(int d, operand_type type, size_t code_size,
|
|
|
|
|
void* code_ptr = nullptr)
|
|
|
|
|
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
|
|
|
|
|
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
|
|
|
|
|
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
|
|
|
|
|
type_ == operand_type::identity)) {
|
|
|
|
|
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
|
|
|
|
|
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
|
|
|
|
|
type_ == operand_type::IDENTITY)) {
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
}
|
|
|
|
|
this->genCode();
|
|
|
|
@ -263,19 +263,19 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
const char* name() const override {
|
|
|
|
|
std::string base = "VActJitCode";
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
case operand_type::RELU:
|
|
|
|
|
base += "_Relu";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
case operand_type::EXP:
|
|
|
|
|
base += "_Exp";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
case operand_type::SIGMOID:
|
|
|
|
|
base += "_Sigmoid";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
case operand_type::TANH:
|
|
|
|
|
base += "_Tanh";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
case operand_type::IDENTITY:
|
|
|
|
|
base += "_Identity";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
@ -305,11 +305,11 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
: VActJitCode(d, op_type, code_size, code_ptr) {} \
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_ACT_JITCODE(VRelu, operand_type::relu);
|
|
|
|
|
DECLARE_ACT_JITCODE(VIdentity, operand_type::identity);
|
|
|
|
|
DECLARE_ACT_JITCODE(VExp, operand_type::exp);
|
|
|
|
|
DECLARE_ACT_JITCODE(VSigmoid, operand_type::sigmoid);
|
|
|
|
|
DECLARE_ACT_JITCODE(VTanh, operand_type::tanh);
|
|
|
|
|
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
|
|
|
|
|
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
|
|
|
|
|
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
|
|
|
|
|
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
|
|
|
|
|
DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
|
|
|
|
|
|
|
|
|
|
#undef DECLARE_ACT_JITCODE
|
|
|
|
|
|
|
|
|
|