From 6be273cbdbf3fe46fe5a4b4af787977b9bd59929 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Fri, 24 Aug 2018 22:40:38 +0800 Subject: [PATCH 1/5] add seq mode lstm --- paddle/fluid/operators/fusion_lstm_op.cc | 52 ++++++++++++++++++++---- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 3888333ec5..870292827d 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,10 +15,14 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include <string> #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/cpu_info.h" + +DEFINE_bool(seq_mode, true, "Use sequence mode"); namespace paddle { namespace operators { @@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); - int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + int xx_width; + if (FLAGS_seq_mode) { + xx_width = wx_dims[1]; + } else { + xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + } ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->ShareLoD("X", "XX"); } @@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx, row_shuffle(ctx, src, index_lod, dst, indexed_src); } -template <typename DeviceContext, typename T> +template <typename T> class FuisonLSTMKernel : public framework::OpKernel<T> { public: - void Compute(const framework::ExecutionContext& ctx) const override { + void SeqCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* x = ctx.Input<LoDTensor>("X"); + auto* wx = ctx.Input<Tensor>("WeightX"); + auto* wh = ctx.Input<Tensor>("WeightH"); + auto* bias = ctx.Input<Tensor>("Bias"); + + auto* xx = ctx.Output<LoDTensor>("XX"); + + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int M = x_dims[1]; // x frame size + const int D4 = wh_dims[1]; + + const T* x_data = x->data<T>(); + const T* wx_data = wx->data<T>(); + T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); + + auto blas = math::GetBlas<DeviceContext, T>(ctx); + math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data, + xx_data, bias->data<T>()); + } + + void BatchCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = platform::CPUDeviceContext; auto* x = ctx.Input<LoDTensor>("X"); auto* wx = ctx.Input<Tensor>("WeightX"); auto* wh = ctx.Input<Tensor>("WeightH"); @@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { // restore the output cell state in LoDTensor from the batch cell to_seq(dev_ctx, batch_cell, cell_out); } + void Compute(const framework::ExecutionContext& ctx) const override { + if (FLAGS_seq_mode) { + SeqCompute(ctx); + } else { + BatchCompute(ctx); + } + } }; } // namespace operators @@ -348,7 +388,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker<true>); -REGISTER_OP_CPU_KERNEL( - fusion_lstm, - 
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>, - ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>); +REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>, + ops::FuisonLSTMKernel<double>); From 607c41952e78d8c5d489a75590204f802d392ee5 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sun, 26 Aug 2018 16:10:45 +0800 Subject: [PATCH 2/5] compute gates --- paddle/fluid/operators/fusion_lstm_op.cc | 87 +++++++++++++++++++++++- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 870292827d..604c6f1839 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input<LoDTensor>("X"); + auto* h0 = ctx.Input<Tensor>("H0"); + auto* c0 = ctx.Input<Tensor>("C0"); auto* wx = ctx.Input<Tensor>("WeightX"); auto* wh = ctx.Input<Tensor>("WeightH"); auto* bias = ctx.Input<Tensor>("Bias"); auto* xx = ctx.Output<LoDTensor>("XX"); + auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); + auto* cell_out = ctx.Output<LoDTensor>("Cell"); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D - const int M = x_dims[1]; // x frame size + auto x_lod = x->lod(); + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int N = x_lod[0].size() - 1; // batch size + const int M = x_dims[1]; // x frame size + const int D = wh_dims[0]; + const int D2 = D * 2; + const int D3 = D * 3; const int D4 = wh_dims[1]; const T* x_data = x->data<T>(); + const T* h0_data = h0 ? h0->data<T>() : NULL; + const T* c0_data = c0 ? 
c0->data<T>() : NULL; const T* wx_data = wx->data<T>(); + const T* wh_data = wh->data<T>(); T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); + T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace()); auto blas = math::GetBlas<DeviceContext, T>(ctx); math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>()); + + for (int i = 0; i < N; ++i) { + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + i * D; + prev_cell_data = c0_data + i * D; + } else { + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh<T>(D, xx_data, xx_data); + // cell out= input*tilde + blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh<T>(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + tstart = 1; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + for (int step = tstart; step < seq_len; ++step) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1), + prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data, + D4); + + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh<T>(D, xx_data, xx_data); + + // a = forget * prev_cell + blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); + + // b = input * tilde + blas.VMUL(D, xx_data, xx_data + D, xx_data + D); + + // cell out= a+b + blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh<T>(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + } } void BatchCompute(const framework::ExecutionContext& ctx) const { From 4b28fab8c94863d5ff24ce4c59ff31bb5d06b4ee Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sun, 26 Aug 2018 18:24:00 +0800 Subject: [PATCH 3/5] enable more acts --- paddle/fluid/operators/fusion_lstm_op.cc | 34 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 2 +- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 604c6f1839..97852e2928 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); auto* cell_out = ctx.Output<LoDTensor>("Cell"); + std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; + auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); + auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); + auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations<T, platform::jit::avx> act_functor; + act_gate = 
act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations<T, platform::jit::isa_any> act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } + auto x_lod = x->lod(); auto x_dims = x->dims(); // T x M auto wh_dims = wh->dims(); // D x 4D @@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { prev_cell_data = c0_data + i * D; } else { // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh<T>(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // cell out= input*tilde blas.VMUL(D, xx_data, xx_data + D, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh<T>(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { D4); // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh<T>(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // a = forget * prev_cell blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); @@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh<T>(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 9d8bef677f..d807f0a8b6 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -45,7 +45,7 @@ def fusion_lstm( class TestLstmOp(OpTest): def set_argument(self): - self.lod = [[2, 3, 2]] + pass def setUp(self): self.op_type = 'fusion_lstm' From 1777cd09f652e18c85a5017058cd29c4794446fa Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sun, 26 Aug 2018 18:42:20 +0800 Subject: [PATCH 4/5] refine fusion lstm op test --- .../tests/unittests/test_fusion_lstm_op.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index d807f0a8b6..19f22fc7bd 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -43,13 +43,13 @@ def fusion_lstm( act_cell, act_cand) -class TestLstmOp(OpTest): - def set_argument(self): +class TestFusionLSTMOp(OpTest): + def set_conf(self): pass def setUp(self): self.op_type = 'fusion_lstm' - self.lod = [[2, 3, 2]] + self.lod = [[2, 3, 5, 4]] self.M = 8 self.D = 16 self.has_initial_state = False @@ -58,33 +58,33 @@ class TestLstmOp(OpTest): self.act_cell = 'tanh' self.act_cand = 'tanh' self.use_peepholes = False - self.set_argument() + self.set_conf() T = sum(self.lod[0]) bs = len(self.lod[0]) - x = np.random.normal(size=(T, self.M)).astype('float64') + x = np.random.normal(size=(T, self.M)).astype('float32') if self.has_initial_state: - h0 = 
np.random.normal(size=(bs, self.D)).astype('float64') - c0 = np.random.normal(size=(bs, self.D)).astype('float64') + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') else: - h0 = np.zeros((bs, self.D)).astype('float64') - c0 = np.zeros((bs, self.D)).astype('float64') + h0 = np.zeros((bs, self.D)).astype('float32') + c0 = np.zeros((bs, self.D)).astype('float32') - wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') + wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32') if self.use_peepholes: - b = np.random.normal(size=(1, 7 * self.D)).astype('float64') + b = np.random.normal(size=(1, 7 * self.D)).astype('float32') else: - b = np.random.normal(size=(1, 4 * self.D)).astype('float64') + b = np.random.normal(size=(1, 4 * self.D)).astype('float32') w_b = np.copy(b[:, 0:4 * self.D]) w_c = b[:, 4 * self.D:] if self.use_peepholes else None # this is the weight of fc - wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64') + wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32') # this is the bias of fc # and it should be manually added into the bias of this fusion LSTM - bx = np.random.normal(size=(1, 4 * self.D)).astype('float64') + bx = np.random.normal(size=(1, 4 * self.D)).astype('float32') b[0, 0:4 * self.D] += bx[0, :] h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c, self.is_reverse, ACTIVATION[self.act_gate], @@ -114,35 +114,44 @@ class TestLstmOp(OpTest): } def test_check_output(self): - self.check_output(atol=1e-8) + self.check_output() -class TestLstmOpInitReverse(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpInit(TestFusionLSTMOp): + def set_conf(self): self.has_initial_state = True - self.is_reverse = True -class TestLstmOpMD1(TestLstmOp): - def set_argument(self): +# class TestFusionLSTMOpReverse(TestFusionLSTMOp): +# def set_conf(self): +# self.is_reverse = True + +# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): +# def set_conf(self): +# self.has_initial_state = True +# self.is_reverse = True + + +class TestFusionLSTMOpMD1(TestFusionLSTMOp): + def set_conf(self): self.M = 36 self.D = 8 -class TestLstmOpMD2(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpMD2(TestFusionLSTMOp): + def set_conf(self): self.M = 8 self.D = 8 -class TestLstmOpMD3(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpMD3(TestFusionLSTMOp): + def set_conf(self): self.M = 15 self.D = 3 -class TestLstmOpBS1(TestLstmOp): - def set_argument(self): +class TestFusionLSTMOpBS1(TestFusionLSTMOp): + def set_conf(self): self.lod = [[3]] self.D = 16 From e61cf3214da019ca1de1fb68ae143928877b4e62 Mon Sep 17 00:00:00 2001 From: tensor-tang <tangjian03@baidu.com> Date: Sun, 26 Aug 2018 21:00:56 +0800 Subject: [PATCH 5/5] complete reverse seq --- paddle/fluid/operators/fusion_lstm_op.cc | 41 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 17 ++++---- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 97852e2928..e4e4ac8e33 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { auto* xx = ctx.Output<LoDTensor>("XX"); auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); auto* cell_out = ctx.Output<LoDTensor>("Cell"); + bool is_reverse = ctx.Attr<bool>("is_reverse"); std::function<void(const int, const T *, T 
*)> act_gate, act_cell, act_cand; auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); @@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { } auto x_lod = x->lod(); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; // batch size const int M = x_dims[1]; // x frame size const int D = wh_dims[0]; @@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace()); auto blas = math::GetBlas<DeviceContext, T>(ctx); - math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data, + math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>()); + int xx_offset = D4; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 4; + hidden_out_data = hidden_out_data + offset; + cell_out_data = cell_out_data + offset; + xx_offset = -D4; + gate_offset = -D; + } + + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + cell_out_data = cell_out_data + gate_offset; + }; for (int i = 0; i < N; ++i) { - int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + int bid = is_reverse ? N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_cell_data = NULL; const T* prev_hidden_data = NULL; int tstart = 0; if (h0_data) { - prev_hidden_data = h0_data + i * D; - prev_cell_data = c0_data + i * D; + prev_hidden_data = h0_data + bid * D; + prev_cell_data = c0_data + bid * D; } else { // W_ch, W_ih, W_fh, W_oh act_gate(D3, xx_data + D, xx_data + D); @@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { prev_cell_data = cell_out_data; tstart = 1; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } for (int step = tstart; step < seq_len; ++step) { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1), @@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { prev_hidden_data = hidden_out_data; prev_cell_data = cell_out_data; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } } } diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 19f22fc7bd..5805bdf461 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -122,14 +122,15 @@ class TestFusionLSTMOpInit(TestFusionLSTMOp): self.has_initial_state = True -# class TestFusionLSTMOpReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.is_reverse = True - -# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.has_initial_state = True -# self.is_reverse = True +class TestFusionLSTMOpReverse(TestFusionLSTMOp): + def set_conf(self): + self.is_reverse = True + + +class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True class TestFusionLSTMOpMD1(TestFusionLSTMOp):
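
Some review notes on the series, for readers following along.

The math that patches 1 and 2 implement with BLAS calls is easiest to check against a scalar reference. SeqCompute hoists the input projection for all T steps out of the recurrence as a single FCCompute (xx = X*Wx + b, shape T x 4D); the per-step GEMM then accumulates h_{t-1}*Wh into the current 4D-wide row with beta = 1, so only a 1 x D by D x 4D product remains inside the loop. The sketch below assumes the gate layout the code comments give (W_ch, W_ih, W_fh, W_oh, i.e. candidate, input, forget, output blocks in that order) and the default activations from the tests (sigmoid gates, tanh candidate and cell); lstm_step_ref and sigmoidf are illustrative names, not part of the patch.

    #include <cmath>

    static float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // One recurrence step on a 4D-wide gate row that already holds
    // x_t*Wx + b + h_{t-1}*Wh. prev_c may be all zeros: that is the t = 0
    // case without C0, where the patch skips the GEMM and the forget
    // product entirely and starts the main loop at tstart = 1.
    void lstm_step_ref(int D, const float* gates, const float* prev_c,
                       float* c_out, float* h_out) {
      const float* cand  = gates;          // W_ch block: candidate c~
      const float* igate = gates + D;      // W_ih block: input gate
      const float* fgate = gates + 2 * D;  // W_fh block: forget gate
      const float* ogate = gates + 3 * D;  // W_oh block: output gate
      for (int d = 0; d < D; ++d) {
        float c_tilde = std::tanh(cand[d]);      // act_cand over [0, D)
        float i = sigmoidf(igate[d]);            // act_gate over [D, 4D)
        float f = sigmoidf(fgate[d]);
        float o = sigmoidf(ogate[d]);
        c_out[d] = f * prev_c[d] + i * c_tilde;  // the two VMULs plus VADD
        h_out[d] = o * std::tanh(c_out[d]);      // act_cell, then VMUL
      }
    }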
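
Patch 3 replaces the hard-coded vec_sigmoid/vec_tanh calls with activations chosen per attribute and per CPU feature. The pattern is worth naming: detect the ISA once, bind the specialized kernel into a std::function, and keep the recurrence loop ISA-agnostic. A self-contained miniature, where has_avx() and sigmoid_fast stand in for platform::jit::MayIUse(platform::jit::avx) and the vectorized kernels in math/cpu_vec.h:

    #include <cmath>
    #include <functional>

    static void sigmoid_scalar(const int n, const float* x, float* y) {
      for (int i = 0; i < n; ++i) y[i] = 1.0f / (1.0f + std::exp(-x[i]));
    }
    // Placeholder body: the real op binds an AVX-vectorized kernel here.
    static void sigmoid_fast(const int n, const float* x, float* y) {
      sigmoid_scalar(n, x, y);
    }
    static bool has_avx() { return false; }  // stand-in for the CPUID check

    int main() {
      std::function<void(const int, const float*, float*)> act_gate =
          has_avx() ? sigmoid_fast : sigmoid_scalar;
      float in[4] = {-2.0f, -1.0f, 1.0f, 2.0f}, out[4];
      act_gate(4, in, out);  // one indirect call per D-sized slab
      return 0;
    }

The std::function indirection costs one indirect call per invocation, but each invocation covers D (or 3D) contiguous elements, so the overhead is negligible next to the per-step GEMM.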
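
One detail that trips up readers of the tests: the Python side writes self.lod = [[2, 3, 5, 4]] as per-sequence lengths, while the C++ kernel sees the LoD level as cumulative offsets — hence N = x_lod[0].size() - 1 for the batch size, and seq_len as a difference of neighbouring entries. A minimal sketch, assuming the usual offset encoding:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      // Lengths [[2, 3, 5, 4]] stored as offsets, one entry per boundary.
      std::vector<std::size_t> lod0 = {0, 2, 5, 10, 14};
      const int N = static_cast<int>(lod0.size()) - 1;  // batch size: 4
      for (int i = 0; i < N; ++i) {
        const int seq_len = static_cast<int>(lod0[i + 1] - lod0[i]);
        // Rows [lod0[i], lod0[i+1]) of X/XX/Hidden/Cell are sequence i.
        std::printf("sequence %d: length %d\n", i, seq_len);
      }
      return 0;
    }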
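
Patch 5's reverse mode is pure pointer bookkeeping: XX, Hidden and Cell keep their forward row order, and the kernel simply starts at the last row and walks backwards (xx_offset = -D4, gate_offset = -D, wrapped in the move_step lambda). A worked example with lod [[2, 3]] and D = 16, so total_T = 5 and D4 = 64: the cursor starts at row 4 and steps through rows 3, 2, 1, 0, which is why the batch must also be consumed in reverse order — bid = N - 1 - i visits sequence 1 (rows 2-4) first and sequence 0 (rows 0-1) second, keeping each sequence's rows contiguous under the negative stride. This is also why patch 4 temporarily comments out the reverse test cases: they can only pass once this traversal lands in patch 5.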