|
|
|
@ -59,43 +59,12 @@ extern int g_tmp_mem[];
|
|
|
|
|
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
|
|
|
|
|
class VActJitCode : public JitCode {
|
|
|
|
|
class VActFunc : public JitCode {
|
|
|
|
|
public:
|
|
|
|
|
explicit VActJitCode(int d, operand_type type, size_t code_size,
|
|
|
|
|
void* code_ptr = nullptr)
|
|
|
|
|
: JitCode(code_size, code_ptr), num_(d), type_(type) {
|
|
|
|
|
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
|
|
|
|
|
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
|
|
|
|
|
type_ == operand_type::identity)) {
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
}
|
|
|
|
|
this->genCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char* name() const override {
|
|
|
|
|
std::string base = "VActJitCode";
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
base += "_Relu";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
base += "_Exp";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
base += "_Sigmoid";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
base += "_Tanh";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
base += "_Identity";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
return base.c_str();
|
|
|
|
|
}
|
|
|
|
|
void genCode() override;
|
|
|
|
|
explicit VActFunc(size_t code_size, void* code_ptr)
|
|
|
|
|
: JitCode(code_size, code_ptr) {}
|
|
|
|
|
virtual const char* name() const = 0;
|
|
|
|
|
virtual void genCode() = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// compute relu with ymm, xmm
|
|
|
|
@ -272,10 +241,49 @@ class VActJitCode : public JitCode {
|
|
|
|
|
identity_jmm<JMM>(dst, src, 15);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class VActJitCode : public VActFunc {
|
|
|
|
|
public:
|
|
|
|
|
explicit VActJitCode(int d, operand_type type, size_t code_size,
|
|
|
|
|
void* code_ptr = nullptr)
|
|
|
|
|
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
|
|
|
|
|
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
|
|
|
|
|
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
|
|
|
|
|
type_ == operand_type::identity)) {
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
}
|
|
|
|
|
this->genCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char* name() const override {
|
|
|
|
|
std::string base = "VActJitCode";
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
base += "_Relu";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
base += "_Exp";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
base += "_Sigmoid";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
base += "_Tanh";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
base += "_Identity";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
return base.c_str();
|
|
|
|
|
}
|
|
|
|
|
void genCode() override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
int num_;
|
|
|
|
|