|
|
|
@ -215,46 +215,53 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
|
|
|
|
|
template <typename T>
|
|
|
|
|
class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0");
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0");
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX");
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH");
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX");
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell");
|
|
|
|
|
#define INIT_VEC_FUNC \
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
|
|
|
|
|
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
|
|
|
|
|
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) { \
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor; \
|
|
|
|
|
act_gate = act_functor(act_gate_str); \
|
|
|
|
|
act_cell = act_functor(act_cell_str); \
|
|
|
|
|
act_cand = act_functor(act_cand_str); \
|
|
|
|
|
} else { \
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor; \
|
|
|
|
|
act_gate = act_functor(act_gate_str); \
|
|
|
|
|
act_cell = act_functor(act_cell_str); \
|
|
|
|
|
act_cand = act_functor(act_cand_str); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INIT_BASE_INPUT_OUTPUT \
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X"); \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse");
|
|
|
|
|
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
|
|
|
|
|
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
|
|
|
|
|
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) {
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
} else {
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
}
|
|
|
|
|
#define INIT_BASE_SIZES \
|
|
|
|
|
auto x_dims = x->dims(); /* T x M*/ \
|
|
|
|
|
auto wh_dims = wh->dims(); /* D x 4D*/ \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const int D3 = D * 3; \
|
|
|
|
|
const int D4 = wh_dims[1];
|
|
|
|
|
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
INIT_BASE_SIZES
|
|
|
|
|
INIT_VEC_FUNC
|
|
|
|
|
|
|
|
|
|
auto x_lod = x->lod();
|
|
|
|
|
auto x_dims = x->dims(); // T x M
|
|
|
|
|
auto wh_dims = wh->dims(); // D x 4D
|
|
|
|
|
const int total_T = x_dims[0];
|
|
|
|
|
const int N = x_lod[0].size() - 1; // batch size
|
|
|
|
|
const int M = x_dims[1]; // x frame size
|
|
|
|
|
const int D = wh_dims[0];
|
|
|
|
|
const int D2 = D * 2;
|
|
|
|
|
const int D3 = D * 3;
|
|
|
|
|
const int D4 = wh_dims[1];
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* h0_data = h0 ? h0->data<T>() : NULL;
|
|
|
|
@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
void BatchCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX");
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH");
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0");
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0");
|
|
|
|
|
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX");
|
|
|
|
|
INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
if (x->lod()[0].size() == 2) { // batch size == 1
|
|
|
|
|
SeqCompute(ctx);
|
|
|
|
|
}
|
|
|
|
|
INIT_BASE_SIZES
|
|
|
|
|
INIT_VEC_FUNC
|
|
|
|
|
|
|
|
|
|
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
|
|
|
|
|
auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
|
|
|
|
|
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
|
|
|
|
|
auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
|
|
|
|
|
auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell");
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse");
|
|
|
|
|
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
|
|
|
|
|
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
|
|
|
|
|
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) {
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
} else {
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims(); // T x M
|
|
|
|
|
auto wh_dims = wh->dims(); // D x 4D
|
|
|
|
|
|
|
|
|
|
// auto x_lod = x->lod();
|
|
|
|
|
// const int N = x_lod[0].size() - 1; // batch size
|
|
|
|
|
// if (N == 1) {
|
|
|
|
|
// SeqCompute(ctx);
|
|
|
|
|
// }
|
|
|
|
|
const int M = x_dims[1];
|
|
|
|
|
const int D = wh_dims[0];
|
|
|
|
|
const int D2 = D * 2;
|
|
|
|
|
const int D3 = D * 3;
|
|
|
|
|
const int D4 = wh_dims[1];
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* wx_data = wx->data<T>();
|
|
|
|
@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
// W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
act_gate(D3, cur_in_data + D, cur_in_data + D);
|
|
|
|
|
act_cand(D, cur_in_data, cur_in_data);
|
|
|
|
|
|
|
|
|
|
// a = forget * prev_cell
|
|
|
|
|
blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
|
|
|
|
|
|
|
|
|
|
// b = input * tilde
|
|
|
|
|
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
|
|
|
|
|
|
|
|
|
|
// cell out= a+b
|
|
|
|
|
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
|
|
|
|
|
|
|
|
|
|
// hidden out= act_state(cellout) * outgate
|
|
|
|
|
act_cell(D, cur_c_out_data, cur_in_data + D2);
|
|
|
|
|
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
|
|
|
|
@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
BatchCompute(ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#undef INIT_BASE_SIZES
|
|
|
|
|
#undef INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
#undef INIT_VEC_FUNC
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|