|
|
|
@ -302,6 +302,34 @@ class VActJitCode : public JitCode {
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
|
|
|
|
|
// use 15
|
|
|
|
|
JMM zero = JMM(15);
|
|
|
|
|
if (type_ == operand_type::relu) {
|
|
|
|
|
vxorps(zero, zero, zero);
|
|
|
|
|
}
|
|
|
|
|
switch (type) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_jmm<JMM>(dst, src, zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
// throw error
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
int num_;
|
|
|
|
|
operand_type type_;
|
|
|
|
@ -386,44 +414,94 @@ class LSTMJitCode : public VActJitCode {
|
|
|
|
|
operand_type act_cand_;
|
|
|
|
|
operand_type act_cell_;
|
|
|
|
|
reg64_t param1{abi_param1};
|
|
|
|
|
|
|
|
|
|
xmm_t xmm_src = xmm_t(0);
|
|
|
|
|
xmm_t xmm_c = xmm_t(1);
|
|
|
|
|
xmm_t xmm_i = xmm_t(2);
|
|
|
|
|
xmm_t xmm_f = xmm_t(3);
|
|
|
|
|
xmm_t xmm_i = xmm_t(6);
|
|
|
|
|
xmm_t xmm_f = xmm_t(7);
|
|
|
|
|
|
|
|
|
|
ymm_t ymm_src = ymm_t(0);
|
|
|
|
|
ymm_t ymm_c = ymm_t(1);
|
|
|
|
|
ymm_t ymm_i = ymm_t(2);
|
|
|
|
|
ymm_t ymm_f = ymm_t(3);
|
|
|
|
|
ymm_t ymm_c = ymm_t(1); // 2~5 for act
|
|
|
|
|
ymm_t ymm_i = ymm_t(6);
|
|
|
|
|
ymm_t ymm_f = ymm_t(7);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
|
|
|
|
|
// use 15
|
|
|
|
|
JMM zero = JMM(15);
|
|
|
|
|
if (type_ == operand_type::relu) {
|
|
|
|
|
vxorps(zero, zero, zero);
|
|
|
|
|
}
|
|
|
|
|
switch (type) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_jmm<JMM>(dst, src, zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
// throw error
|
|
|
|
|
break;
|
|
|
|
|
class GRUJitCode : public VActJitCode {
|
|
|
|
|
public:
|
|
|
|
|
const char* name() const override {
|
|
|
|
|
std::string base = "GRUJitCode";
|
|
|
|
|
if (id_ == 0) {
|
|
|
|
|
base += "_H1";
|
|
|
|
|
} else if (id_ == 1) {
|
|
|
|
|
base += "_HtPart1";
|
|
|
|
|
} else if (id_ == 2) {
|
|
|
|
|
base += "_HtPart2";
|
|
|
|
|
}
|
|
|
|
|
auto AddTypeStr = [&](operand_type type) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
base += "_Relu";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
base += "_Exp";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
base += "_Sigmoid";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
base += "_Tanh";
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
base += "_Identity";
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
AddTypeStr(act_gate_);
|
|
|
|
|
AddTypeStr(act_cand_);
|
|
|
|
|
return base.c_str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
explicit GRUJitCode(int id, const gru_attr_t& attr,
|
|
|
|
|
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
|
|
|
|
|
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
|
|
|
|
|
code_ptr),
|
|
|
|
|
id_(id) {
|
|
|
|
|
auto typeExchange = [](const std::string& type) -> gen::operand_type {
|
|
|
|
|
if (type == "sigmoid") {
|
|
|
|
|
return operand_type::sigmoid;
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
return operand_type::relu;
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
return operand_type::tanh;
|
|
|
|
|
} else if (type == "identity" || type == "") {
|
|
|
|
|
return operand_type::identity;
|
|
|
|
|
} // else throw error
|
|
|
|
|
return operand_type::identity;
|
|
|
|
|
};
|
|
|
|
|
num_ = attr.d;
|
|
|
|
|
act_gate_ = typeExchange(attr.act_gate);
|
|
|
|
|
act_cand_ = typeExchange(attr.act_cand);
|
|
|
|
|
}
|
|
|
|
|
static bool init(int d);
|
|
|
|
|
void generate() override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
int id_;
|
|
|
|
|
int num_;
|
|
|
|
|
operand_type act_gate_;
|
|
|
|
|
operand_type act_cand_;
|
|
|
|
|
reg64_t param1{abi_param1};
|
|
|
|
|
|
|
|
|
|
xmm_t xmm_src = xmm_t(0);
|
|
|
|
|
xmm_t xmm_c = xmm_t(1);
|
|
|
|
|
xmm_t xmm_i = xmm_t(6);
|
|
|
|
|
xmm_t xmm_f = xmm_t(7);
|
|
|
|
|
|
|
|
|
|
ymm_t ymm_src = ymm_t(0);
|
|
|
|
|
ymm_t ymm_c = ymm_t(1);
|
|
|
|
|
ymm_t ymm_i = ymm_t(6);
|
|
|
|
|
ymm_t ymm_f = ymm_t(7);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|