|
|
|
@ -75,6 +75,12 @@ class VActFunc : public JitCode {
|
|
|
|
|
vmaxps(dst, src, zero);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute SQUARE with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void square_jmm(JMM& dst, JMM& src) { // NOLINT
|
|
|
|
|
vmulps(dst, src, src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute EXP with ymm, xmm
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
|
|
|
|
@ -228,6 +234,9 @@ class VActFunc : public JitCode {
|
|
|
|
|
case operand_type::RELU:
|
|
|
|
|
relu_jmm<JMM>(dst, src, 15);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::SQUARE:
|
|
|
|
|
square_jmm<JMM>(dst, src);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::EXP:
|
|
|
|
|
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
|
|
|
|
|
break;
|
|
|
|
@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
|
|
|
|
|
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
|
|
|
|
|
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
|
|
|
|
|
type_ == operand_type::IDENTITY)) {
|
|
|
|
|
type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
}
|
|
|
|
|
this->genCode();
|
|
|
|
@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
case operand_type::RELU:
|
|
|
|
|
base += "_Relu";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::SQUARE:
|
|
|
|
|
base += "_Square";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::EXP:
|
|
|
|
|
base += "_Exp";
|
|
|
|
|
break;
|
|
|
|
@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
|
|
|
|
|
DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE);
|
|
|
|
|
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
|
|
|
|
|
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
|
|
|
|
|
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
|
|
|
|
|