From 8728b3cce24c69f76167d843b9bb667027110c56 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 12 Oct 2017 11:30:44 +0800 Subject: [PATCH 1/9] Add LSTM Operators. --- paddle/operators/lstm_op.cc | 185 ++++++++++++++++++++++++ paddle/operators/lstm_op.h | 38 +++++ paddle/operators/lstm_unit_op.h | 1 - paddle/operators/math/cross_entropy.cu | 2 - paddle/operators/math/sequence2batch.cc | 26 ++++ paddle/operators/math/sequence2batch.cu | 26 ++++ paddle/operators/math/sequence2batch.h | 113 +++++++++++++++ 7 files changed, 388 insertions(+), 3 deletions(-) create mode 100644 paddle/operators/lstm_op.cc create mode 100644 paddle/operators/lstm_op.h create mode 100644 paddle/operators/math/sequence2batch.cc create mode 100644 paddle/operators/math/sequence2batch.cu create mode 100644 paddle/operators/math/sequence2batch.h diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc new file mode 100644 index 0000000000..6233e12923 --- /dev/null +++ b/paddle/operators/lstm_op.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of LSTM should not be null."); + + auto x_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + ctx->SetOutputDim("Hidden", x_dims); + ctx->SetOutputDim("Cell", x_dims); + ctx->ShareLoD("Input", "Hidden"); + ctx->ShareLoD("Input", "Cell"); + } +}; + +class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTenosr is a matrix with shape (T X D), where, T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size, D is the hidden size."); + AddInput("C0", + "(Tensor, optional) the initial cell state is an optional " + "input. 
This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time"); + AddInput("Weight", + "(Tensor) the learnable hidden-hidden weights." + " - The shape is (D x 4*D), where D is the hidden size. " + " - Weight = {W_ih, W_fh, W_ch, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights, which contains two parts: " + "input-hidden bias weight and peephole connections weight if " + "seting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4*D). " + " - Bias = {b_i, b_f, b_c, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7*D). " + " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); + AddOutput("Hidden", + "(LoDTensor) the hidden state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr( + "gate_activation", + "(string, defalut: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by defalut.") + .SetDefault("sigmoid"); + AddAttr("cell_activation", + "(string, defalut: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh"); + AddAttr("candidate_activation", + "(string, defalut: tanh)" + "The activation for candidate hidden state, " + "`tanh` by defalut.") + .SetDefault("tanh"); + AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator + +The defalut implementation is diagonal/peephole connection [1], the formula is +as follows + + i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) + + f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) + + \tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + + o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) + + c_t = f_t ⊙ 
c_{t-1} + i_t ⊙ \tilde{c_t} + + h_t = o_t ⊙ act_h(c_t) + +where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix +of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$ +are diagonal weight matrices for peephole connections. In our implenmention, +We use vectors to reprenset these diagonal weight matrices. The b terms +denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$ +is the non-line actications, such as logistic sigmoid function, and +\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate, +output gate and cell activation vectors, all of which are the same size as +the cell output activation vector \f$h\f$. + +The ⊙ is the element-wise product of the vectors, \f$act_g\f$ and \f$act_h\f$ +are the cell input and cell output activation functions, `tanh` is usually +used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, +which is computed based on the current input and the previous hidden state. + +Set `use_peepholes` False to disable peephole connection [2]. The formula +is omitted here. + +@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ +operations on the input x_{t} were NOT included in this operator. The +users can choose to use fully-connect operator before LSTM operator. + +[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory +recurrent neural network architectures for large scale acoustic modeling. +INTERSPEECH, 2014. + +[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory. +Neural Computation, 9(8):1735-1780, 1997. 
+ +)DOC"); + } +}; + +class LSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(Hidden@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), + "Input(Cell@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("Weight"), + ctx->GetInputDim("Weight")); + ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp); +REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_CPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h new file mode 100644 index 0000000000..6e77cadead --- /dev/null +++ b/paddle/operators/lstm_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +class LSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +template +class LSTMGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h index a0ff498c1d..625b1852c2 100644 --- a/paddle/operators/lstm_unit_op.h +++ b/paddle/operators/lstm_unit_op.h @@ -19,7 +19,6 @@ namespace paddle { namespace operators { -using framework::LoDTensor; using framework::Tensor; template diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index 367190e6b0..db878129d6 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -22,8 +22,6 @@ namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc new file mode 100644 index 0000000000..c29baaae08 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensor2Functor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu new file mode 100644 index 0000000000..5afb87e4a4 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensor2Functor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h new file mode 100644 index 0000000000..6ee870cf78 --- /dev/null +++ b/paddle/operators/math/sequence2batch.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +namespace paddle { +namespace operators { +namespace math { + +template +class LoDTensor2BatchFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& lod_tensor, + framework::LoDTensor& batch, const bool is_reverse) const { + auto lods = lod_tensor->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + auto lod = lods[0]; + + // Calculate the length of each sequence and + // sort sequence index by the length. 
+ // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seqIdx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + + std::vector seq_info; + for (size_t seq_id = 0; seq_id < lod.size(); ++seq_id) { + int length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), + [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + + // calculate the start position of each batch + // (numBatch equal the maxLength of sequences) + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // num_batch = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + + // The batch number represents batch size after rearranging the + // input LodTensor. It is also the maximum length of input sequence. 
+ auto batch_lods = batch->lod(); + if (!batch_lods) { + batch_lods->resize(2); + } + // batch_lods[0] is the start positions for batch LoDTensor + int num_batch = (size_t)seq_info[0].length; + batch_lods[0]->resize(num_batch + 1); + // batch_lods[1] is the raw index in the input LoDTensor + auto dims = lod_tensor->dims(); + batch_lods[1]->resize(dims[0]); + + auto* batch_starts = batch_lods[0].data(); + auto* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (size_t n = 0; n < num_batch; n++) { + int batch_id = batch_starts[n]; + for (size_t i = 0; i < seq_info.size(); ++i) { + size_t seq_len = seq_info[i].length; + int start = seq_info[i].start; + if (n < seq_len) { + if (!is_reverse) { + seq2batch_idx[batch_id] = start + n; + } else { + seq2batch_idx[batch_id] = start + seq_len - 1 - n; + } + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = batch_id; + } + } +} + +template +class Batch2LoDTensor2Functor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& batch, + framework::LoDTensor& lod_tensor, + const bool is_reverse) const; + +} // namespace math +} // namespace operators +} // namespace paddle From 3cace73701a052c6593f6cf9151be14c3874f2e8 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 16 Oct 2017 13:23:08 +0800 Subject: [PATCH 2/9] Add lstm implementation. 
--- paddle/operators/lstm_op.cc | 54 +++- paddle/operators/lstm_op.h | 35 +- .../math/detail/hl_activation_functions.h | 64 ++++ .../operators/math/detail/hl_avx_functions.cc | 68 ++++ .../operators/math/detail/hl_avx_functions.h | 32 ++ .../operators/math/detail/hl_cpu_functions.cc | 44 +++ paddle/operators/math/detail/hl_functions.h | 63 ++++ .../operators/math/detail/hl_gpu_functions.h | 80 +++++ .../operators/math/detail/lstm_cpu_kernel.h | 306 ++++++++++++++++++ .../operators/math/detail/lstm_gpu_kernel.h | 244 ++++++++++++++ paddle/operators/math/detail/lstm_kernel.h | 138 ++++++++ paddle/operators/math/lstm_compute.cc | 73 +++++ paddle/operators/math/lstm_compute.cu | 73 +++++ paddle/operators/math/lstm_compute.h | 87 +++++ paddle/operators/math/sequence2batch.cc | 31 ++ paddle/operators/math/sequence2batch.cu | 47 +++ paddle/operators/math/sequence2batch.h | 19 +- 17 files changed, 1436 insertions(+), 22 deletions(-) create mode 100644 paddle/operators/math/detail/hl_activation_functions.h create mode 100644 paddle/operators/math/detail/hl_avx_functions.cc create mode 100644 paddle/operators/math/detail/hl_avx_functions.h create mode 100644 paddle/operators/math/detail/hl_cpu_functions.cc create mode 100644 paddle/operators/math/detail/hl_functions.h create mode 100644 paddle/operators/math/detail/hl_gpu_functions.h create mode 100644 paddle/operators/math/detail/lstm_cpu_kernel.h create mode 100644 paddle/operators/math/detail/lstm_gpu_kernel.h create mode 100644 paddle/operators/math/detail/lstm_kernel.h create mode 100644 paddle/operators/math/lstm_compute.cc create mode 100644 paddle/operators/math/lstm_compute.cu create mode 100644 paddle/operators/math/lstm_compute.h diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 6233e12923..1803aa1e44 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -1,18 +1,18 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
- Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ -#include "paddle/operators/lstm_unit_op.h" +#include "paddle/operators/lstm_op.h" namespace paddle { namespace operators { @@ -44,8 +44,36 @@ class LSTMOp : public framework::OperatorWithKernel { "should be the same."); } + int frame_size = x_dims[1]; + auto w_dims = ctx->GetInputDim("Weight"); + PADDLE_ENFORCE_EQ(w_dims.size(), 2, + "The rank of Input(Weight) should be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], frame_size, + "The first dimension of Input(Weight) " + "should be %d.", + frame_size); + PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, + "The second dimension of Input(Weight) " + "should be 4 * %d.", + frame_size); + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + if (ctx->Attrs().Get("use_peepholes")) { + PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection", + frame_size); + } else { + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if diable peepholes connection", + frame_size); + } ctx->SetOutputDim("Hidden", x_dims); ctx->SetOutputDim("Cell", x_dims); + ctx->SetOutputDim("Hidden", x_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } @@ -82,6 +110,8 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "2. `use_peepholes = True` " " - The shape is (1 x 7*D). " " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); + AddOutput("Batch", "(LoDTensor) save the reorganized input as batch info. ") + .AsIntermediate(); AddOutput("Hidden", "(LoDTensor) the hidden state lod tensor of LSTM operator. 
" + "The shape and lod is the same with the `Input`."); @@ -92,6 +122,10 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") .SetDefault(true); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); AddAttr( "gate_activation", "(string, defalut: sigmoid)" diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 6e77cadead..037f0485a1 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -1,19 +1,18 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ #pragma once -#include "glog/logging.h" #include "paddle/framework/op_registry.h" namespace paddle { @@ -25,7 +24,21 @@ using framework::Tensor; template class LSTMKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_t = ctx.Input("Input"); + auto* batch_t = ctx.Input("Batch"); + auto* bias_t = ctx.Input("Bias"); + bool is_reverse = ctx.Attr("is_reverse"); + LoDTensor2BatchFunctor to_batch(ctx.device_context(), input_t, + batch_t, is_reverse); + + auto in_dims = input_t->dims(); + int frame_size = in_dims[1]; + + if (bias_t) { + auto b = EigenMatrix::From(*bias); + } + } }; template diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h new file mode 100644 index 0000000000..d5cf874636 --- /dev/null +++ b/paddle/operators/math/detail/hl_activation_functions.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_ACTIVATION_FUNCTIONS_H_ +#define HL_ACTIVATION_FUNCTIONS_H_ + +#include "hl_functions.h" + +/** + * Active functions: sigmoid, relu, tanh and linear. 
+ */ +#define HPPL_ACTIVE_FUNCTION \ + { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear } + +namespace hppl { + +/** + * Hppl supports sigmoid, relu, tanh, linear active functions + * for neural networks' forward and backward activation. + */ +template +class Active { + public: + typedef T (*forward)(T); + typedef T (*backward)(T, T); +}; + +#ifdef __NVCC__ +namespace gpu { +static __device__ Active::forward forward[] = HPPL_ACTIVE_FUNCTION; +static __device__ Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +static __device__ Active::forward forward[] = HPPL_ACTIVE_FUNCTION; +static __device__ Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +} // namespace gpu +#else +namespace cpu { +static Active::forward forward[] = HPPL_ACTIVE_FUNCTION; +static Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +static Active::forward forward[] = HPPL_ACTIVE_FUNCTION; +static Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +} // namespace cpu + +#ifdef __AVX__ +namespace avx { +static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION; +static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION; +} // namespace avx +#endif +#endif + +} // namespace hppl + +#endif // HL_ACTIVATION_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc new file mode 100644 index 0000000000..70e7d80304 --- /dev/null +++ b/paddle/operators/math/detail/hl_avx_functions.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "hl_functions.h" + +namespace hppl { + +extern __m256 exp(__m256 a); + +__m256 relu(const __m256 a) { + __m256 tmp = _mm256_set1_ps(0.0f); + return _mm256_max_ps(a, tmp); +} + +__m256 sigmoid(const __m256 a) { + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); + __m256 tmp = _mm256_max_ps(a, min); + tmp = _mm256_min_ps(tmp, max); + tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); + tmp = exp(tmp); + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); + tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); + return tmp; +} + +__m256 tanh(const __m256 a) { + __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); + __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); + tmp = _mm256_min_ps(tmp, max); + tmp = exp(tmp); + return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), + _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), + _mm256_set1_ps(1.0f)); +} + +__m256 linear(const __m256 a) { return a; } + +__m256 relu(const __m256 a, const __m256 b) { + return _mm256_mul_ps( + a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), + _mm256_set1_ps(1.0f))); +} + +__m256 sigmoid(const __m256 a, const __m256 b) { + return _mm256_mul_ps(_mm256_mul_ps(a, b), + _mm256_sub_ps(_mm256_set1_ps(1.0f), b)); +} + +__m256 tanh(const __m256 a, const __m256 b) { + return _mm256_mul_ps( + a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b))); +} + +__m256 linear(const __m256 a, const __m256 b) { return a; } +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_avx_functions.h b/paddle/operators/math/detail/hl_avx_functions.h new file mode 100644 index 0000000000..35f4eabb4c --- /dev/null +++ b/paddle/operators/math/detail/hl_avx_functions.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_AVX_FUNCTIONS_H_ +#define HL_AVX_FUNCTIONS_H_ + +#include <immintrin.h> + +namespace hppl { +__m256 relu(const __m256 a); +__m256 sigmoid(const __m256 a); +__m256 tanh(const __m256 a); +__m256 linear(const __m256 a); + +__m256 relu(const __m256 a, const __m256 b); +__m256 sigmoid(const __m256 a, const __m256 b); +__m256 tanh(const __m256 a, const __m256 b); +__m256 linear(const __m256 a, const __m256 b); +} // namespace hppl + +#endif // HL_AVX_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc new file mode 100644 index 0000000000..b42e11fd90 --- /dev/null +++ b/paddle/operators/math/detail/hl_cpu_functions.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <math.h> +#include "paddle/operators/math/detail/hl_functions.h" + +namespace hppl { + +real relu(const real a) { return a > 0.0f ? 
a : 0.0f; } + +real sigmoid(const real a) { + const real min = SIGMOID_THRESHOLD_MIN; + const real max = SIGMOID_THRESHOLD_MAX; + real tmp = (a < min) ? min : ((a > max) ? max : a); + return 1.0 / (1.0 + exp(-tmp)); +} + +real tanh(const real a) { + real tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +real linear(const real a) { return a; } + +real relu(const real a, const real b) { return a * (b > 0.0f ? 1.0f : 0.0f); } + +real sigmoid(const real a, const real b) { return a * b * (1 - b); } + +real tanh(const real a, const real b) { return a * (1.0f - b * b); } + +real linear(const real a, const real b) { return a; } +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h new file mode 100644 index 0000000000..4eda1adfe9 --- /dev/null +++ b/paddle/operators/math/detail/hl_functions.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef HL_FUNCTIONS_H_ +#define HL_FUNCTIONS_H_ + +/** + * sigmoid threshold maximum + */ +#define SIGMOID_THRESHOLD_MIN -40.0 + +/** + * sigmoid threshold minimum + */ +#define SIGMOID_THRESHOLD_MAX 13.0 + +#ifndef __NVCC__ +namespace hppl { +/* + * forward activation + */ +template +T relu(const T a); +template +T sigmoid(const T a); +template +T tanh(const T a); +template +T linear(const T a); + +/* + * backward activation + */ +template +T relu(const T a, const T b); +template +T sigmoid(const T a, const T b); +template +T tanh(const T a, const T b); +template +T linear(const T a, const T b); +} // namespace hppl + +#ifdef __AVX__ +#include "hl_avx_functions.h" +#endif + +#else +#include "hl_gpu_functions.h" +#endif + +#endif // HL_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h new file mode 100644 index 0000000000..25fa7c409a --- /dev/null +++ b/paddle/operators/math/detail/hl_gpu_functions.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_GPU_FUNCTIONS_CUH_ +#define HL_GPU_FUNCTIONS_CUH_ + +#include "hl_base.h" + +namespace hppl { + +template +__device__ static T relu(const T a) { + return a > 0.0f ? a : 0.0f; +} + +template <> +__device__ static float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? 
min : ((a > max) ? max : a); + return __fdividef(1.0f, 1.0f + __expf(-tmp)); +} + +template <> +__device__ static double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return 1.0 / (1.0 + exp(-tmp)); +} + +template <> +__device__ static float tanh(const float a) { + return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f; +} + +template <> +__device__ static double tanh(const double a) { + return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0; +} + +template +__device__ static T linear(const T a) { + return a; +} + +template +__device__ static T relu(const T a, const T b) { + return a * (b > 0.0f ? 1.0f : 0.0f); +} + +template +__device__ static T sigmoid(const T a, const T b) { + return a * b * (1 - b); +} + +template +__device__ static T tanh(const T a, const T b) { + return a * (1.0f - b * b); +} + +template +__device__ static T linear(const T a, const T b) { + return a; +} + +} // namespace hppl + +#endif // HL_GPU_FUNCTIONS_CUH_ diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h new file mode 100644 index 0000000000..a8e78a449d --- /dev/null +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -0,0 +1,306 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/operators/math/lstm_compute.h" + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +#ifndef __NVCC__ + +template +void naive_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rCheckI; + T rCheckF; + T rCheckO; + T rState; + T rPrevState = 0; + T rStateAtv; + T rOut; + + T *valueIn = value.gateValue; + T *valueIg = value.gateValue + frameSize; + T *valueFg = value.gateValue + frameSize * 2; + T *valueOg = value.gateValue + frameSize * 3; + + for (int i = 0; i < frameSize; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = value.checkIg[i]; + rCheckF = value.checkFg[i]; + rCheckO = value.checkOg[i]; + + if (value.prevStateValue) { + rPrevState = value.prevStateValue[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, hppl::cpu::forward[active_node], + hppl::cpu::forward[active_gate], hppl::cpu::forward[active_state]); + + valueIn[i] = rValueIn; + valueIg[i] = rValueIg; + valueFg[i] = rValueFg; + valueOg[i] = rValueOg; + value.stateValue[i] = rState; + value.stateActiveValue[i] = rStateAtv; + value.outputValue[i] = rOut; + } +} + +template +void naive_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, + int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rGradIn; + T rGradIg; + T rGradFg; + T rGradOg; + T rPrevState = 0; + T rPrevStateGrad; + T rState; + T rStateGrad; + T rStateAtv; + T rOutputGrad; + T rCheckI; + T rCheckF; + T rCheckO; + T rCheckIGrad; + T rCheckFGrad; + T rCheckOGrad; + + T *valueIn = value.gateValue; + T *valueIg = value.gateValue 
+ frameSize; + T *valueFg = value.gateValue + frameSize * 2; + T *valueOg = value.gateValue + frameSize * 3; + T *gradIn = grad.gateGrad; + T *gradIg = grad.gateGrad + frameSize; + T *gradFg = grad.gateGrad + frameSize * 2; + T *gradOg = grad.gateGrad + frameSize * 3; + + for (int i = 0; i < frameSize; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = value.checkIg[i]; + rCheckF = value.checkFg[i]; + rCheckO = value.checkOg[i]; + rState = value.stateValue[i]; + rStateAtv = value.stateActiveValue[i]; + rOutputGrad = grad.outputGrad[i]; + rStateGrad = grad.stateGrad[i]; + if (value.prevStateValue) { + rPrevState = value.prevStateValue[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, + rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, + rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, + rCheckOGrad, hppl::cpu::backward[active_node], + hppl::cpu::backward[active_gate], hppl::cpu::backward[active_state]); + + gradIn[i] = rGradIn; + gradIg[i] = rGradIg; + gradFg[i] = rGradFg; + gradOg[i] = rGradOg; + grad.stateGrad[i] = rStateGrad; + + if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad; + if (value.prevStateValue) { + if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad; + if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad; + } + if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad; + } +} + +template +void avx_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { +#ifdef __AVX__ + __m256 rValueIn; + __m256 rValueIg; + __m256 rValueFg; + __m256 rValueOg; + __m256 rCheckI; + __m256 rCheckF; + __m256 rCheckO; + __m256 rState; + __m256 rPrevState = _mm256_set1_ps(0.0f); + __m256 rStateAtv; + __m256 rOut; + + __m256 *valueIn = (__m256 *)value.gateValue; + __m256 *valueIg = (__m256 *)(value.gateValue + 
frameSize); + __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); + __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); + + for (int i = 0; i < frameSize / 8; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = ((__m256 *)value.checkIg)[i]; + rCheckF = ((__m256 *)value.checkFg)[i]; + rCheckO = ((__m256 *)value.checkOg)[i]; + + if (value.prevStateValue) { + rPrevState = ((__m256 *)value.prevStateValue)[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node], + hppl::avx::forward[active_gate], hppl::avx::forward[active_state]); + + valueIn[i] = rValueIn; + valueIg[i] = rValueIg; + valueFg[i] = rValueFg; + valueOg[i] = rValueOg; + ((__m256 *)value.stateValue)[i] = rState; + ((__m256 *)value.stateActiveValue)[i] = rStateAtv; + ((__m256 *)value.outputValue)[i] = rOut; + } +#endif +} + +template +void avx_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, + int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { +#ifdef __AVX__ + __m256 rValueIn; + __m256 rValueIg; + __m256 rValueFg; + __m256 rValueOg; + __m256 rGradIn; + __m256 rGradIg; + __m256 rGradFg; + __m256 rGradOg; + __m256 rPrevState = _mm256_set1_ps(0.0f); + __m256 rPrevStateGrad; + __m256 rStateGrad; + __m256 rState; + __m256 rStateAtv; + __m256 rOutputGrad; + __m256 rCheckI; + __m256 rCheckF; + __m256 rCheckO; + __m256 rCheckIGrad; + __m256 rCheckFGrad; + __m256 rCheckOGrad; + + __m256 *valueIn = (__m256 *)value.gateValue; + __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); + __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); + __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); + __m256 *gradIn = (__m256 *)grad.gateGrad; + __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize); + __m256 *gradFg = (__m256 
*)(grad.gateGrad + frameSize * 2); + __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3); + + for (int i = 0; i < frameSize / 8; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = ((__m256 *)value.checkIg)[i]; + rCheckF = ((__m256 *)value.checkFg)[i]; + rCheckO = ((__m256 *)value.checkOg)[i]; + rState = ((__m256 *)value.stateValue)[i]; + rStateAtv = ((__m256 *)value.stateActiveValue)[i]; + rOutputGrad = ((__m256 *)grad.outputGrad)[i]; + rStateGrad = ((__m256 *)grad.stateGrad)[i]; + if (value.prevStateValue) { + rPrevState = ((__m256 *)value.prevStateValue)[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, + rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, + rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, + rCheckOGrad, hppl::avx::backward[active_node], + hppl::avx::backward[active_gate], hppl::avx::backward[active_state]); + + gradIn[i] = rGradIn; + gradIg[i] = rGradIg; + gradFg[i] = rGradFg; + gradOg[i] = rGradOg; + ((__m256 *)grad.stateGrad)[i] = rStateGrad; + + if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad; + if (value.prevStateValue) { + if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad; + if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad; + } + if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad; + } +#endif +} + +template +void cpu_lstm_forward(Op op, lstm_value value, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + avx_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); + } else { + naive_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); + } +} + +template +void cpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int 
frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); + } else { + naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); + } +} + +#endif + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h new file mode 100644 index 0000000000..8d0274c19d --- /dev/null +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -0,0 +1,244 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/operators/math/detail/lstm_kernel.h" +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeLstmForward(Op op, lstm_value value, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + value.gateValue += batchIdx * frameSize * 4; + value.outputValue += batchIdx * frameSize; + value.stateValue += batchIdx * frameSize; + value.stateActiveValue += batchIdx * frameSize; + } + + T rState; + T rPrevState = 0; + T rStateAtv; + T rOut; + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rCheckI = value.checkIg[frameIdx]; + T rCheckF = value.checkFg[frameIdx]; + T rCheckO = value.checkOg[frameIdx]; + + rValueIn = value.gateValue[frameIdx]; + rValueIg = value.gateValue[frameIdx + frameSize]; + rValueFg = value.gateValue[frameIdx + frameSize * 2]; + rValueOg = value.gateValue[frameIdx + frameSize * 3]; + + if (value.prevStateValue) { + if (isBatch) value.prevStateValue += batchIdx * frameSize; + rPrevState = value.prevStateValue[frameIdx]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, hppl::gpu::forward[active_node], + hppl::gpu::forward[active_gate], hppl::gpu::forward[active_state]); + + value.gateValue[frameIdx] = rValueIn; + value.gateValue[frameIdx + frameSize] = rValueIg; + value.gateValue[frameIdx + frameSize * 2] = rValueFg; + value.gateValue[frameIdx + frameSize * 3] = rValueOg; + + value.stateValue[frameIdx] = rState; + 
value.stateActiveValue[frameIdx] = rStateAtv; + value.outputValue[frameIdx] = rOut; +} + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad, + int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + value.gateValue += batchIdx * frameSize * 4; + value.stateValue += batchIdx * frameSize; + value.stateActiveValue += batchIdx * frameSize; + grad.gateGrad += batchIdx * frameSize * 4; + grad.stateGrad += batchIdx * frameSize; + grad.outputGrad += batchIdx * frameSize; + } + + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rGradIn; + T rGradIg; + T rGradFg; + T rGradOg; + T rPrevState = 0; + T rPrevStateGrad; + T rState; + T rStateGrad; + T rStateAtv; + T rOutputGrad; + T rCheckI = value.checkIg[frameIdx]; + T rCheckF = value.checkFg[frameIdx]; + T rCheckO = value.checkOg[frameIdx]; + T rCheckIGrad; + T rCheckFGrad; + T rCheckOGrad; + + rValueIn = value.gateValue[frameIdx]; + rValueIg = value.gateValue[frameIdx + frameSize]; + rValueFg = value.gateValue[frameIdx + frameSize * 2]; + rValueOg = value.gateValue[frameIdx + frameSize * 3]; + rState = value.stateValue[frameIdx]; + rStateAtv = value.stateActiveValue[frameIdx]; + rOutputGrad = grad.outputGrad[frameIdx]; + rStateGrad = grad.stateGrad[frameIdx]; + + if (value.prevStateValue) { + if (isBatch) value.prevStateValue += batchIdx * frameSize; + rPrevState = value.prevStateValue[frameIdx]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, + rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, + rCheckI, rCheckF, rCheckO, rCheckIGrad, 
rCheckFGrad, rCheckOGrad, + hppl::gpu::backward[active_node], hppl::gpu::backward[active_gate], + hppl::gpu::backward[active_state]); + + grad.gateGrad[frameIdx] = rGradIn; + grad.gateGrad[frameIdx + frameSize] = rGradIg; + grad.gateGrad[frameIdx + frameSize * 2] = rGradFg; + grad.gateGrad[frameIdx + frameSize * 3] = rGradOg; + grad.stateGrad[frameIdx] = rStateGrad; + if (grad.prevStateGrad) { + if (isBatch) grad.prevStateGrad += batchIdx * frameSize; + grad.prevStateGrad[frameIdx] = rPrevStateGrad; + } + + if (isBatch) { + if (value.prevStateValue) { + if (grad.checkIgGrad) + paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx, + rCheckIGrad); + if (grad.checkFgGrad) + paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx, + rCheckFGrad); + } + if (grad.checkOgGrad) + paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad); + } else { + if (value.prevStateValue) { + if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad; + if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad; + } + if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad; + } +} + +template +void gpu_lstm_forward(Op op, lstm_value value, int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? 
frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + /* framePerBlock = 32 batchPerBlock = 32 */ + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + if (batchSize == 1) { + KeLstmForward<<>>( + op, value, frameSize, batchSize, active_node, active_gate, + active_state); + } else { + KeLstmForward<<>>( + op, value, frameSize, batchSize, active_node, active_gate, + active_state); + } +} + +template +void gpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + /* framePerBlock = 32 batchPerBlock = 32 */ + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + if (batchSize == 1) { + KeLstmBackward<<>>( + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); + } else { + KeLstmBackward<<>>( + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); + } +} + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h new file mode 100644 index 0000000000..107030f8ba --- /dev/null +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_activation_functions.h" + +#ifdef __CUDA_ARCH__ +#define INLINE __device__ inline +#else +#define INLINE inline +#endif + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +namespace forward { + +template +class lstm { + public: + INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &prevState, T &state, T &stateAtv, T &output, + T &checkI, T &checkF, T &checkO, + Active::forward actInput, + Active::forward actGate, + Active::forward actState) { + valueIn = actInput(valueIn); + valueIg = actGate(valueIg + prevState * checkI); + valueFg = actGate(valueFg + prevState * checkF); + state = valueIn * valueIg + prevState * valueFg; + valueOg = actGate(valueOg + state * checkO); + stateAtv = actState(state); + output = valueOg * stateAtv; + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &prevState, __m256 &state, + __m256 &stateAtv, __m256 &output, __m256 &checkI, + __m256 &checkF, __m256 &checkO, + Active<__m256>::forward actInput, + Active<__m256>::forward actGate, + Active<__m256>::forward actState) { + valueIn = actInput(valueIn); + valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); + valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); + state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), + _mm256_mul_ps(prevState, valueFg)); + valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, 
checkO))); + stateAtv = actState(state); + output = _mm256_mul_ps(valueOg, stateAtv); + } +#endif +#endif +}; + +} // namespace forward + +namespace backward { + +template +class lstm { + public: + INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &gradIn, T &gradIg, T &gradFg, T &gradOg, + T &prevState, T &prevStateGrad, T &state, T &stateGrad, + T &stateAtv, T &outputGrad, T &checkI, T &checkF, + T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, + Active::backward actInput, + Active::backward actGate, + Active::backward actState) { + gradOg = actGate(outputGrad * stateAtv, valueOg); + stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; + gradIn = actInput(stateGrad * valueIg, valueIn); + gradIg = actGate(stateGrad * valueIn, valueIg); + gradFg = actGate(stateGrad * prevState, valueFg); + prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; + checkIGrad = gradIg * prevState; + checkFGrad = gradFg * prevState; + checkOGrad = gradOg * state; + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &gradIn, __m256 &gradIg, + __m256 &gradFg, __m256 &gradOg, __m256 &prevState, + __m256 &prevStateGrad, __m256 &state, + __m256 &stateGrad, __m256 &stateAtv, + __m256 &outputGrad, __m256 &checkI, __m256 &checkF, + __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, + __m256 &checkOGrad, Active<__m256>::backward actInput, + Active<__m256>::backward actGate, + Active<__m256>::backward actState) { + gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg); + stateGrad = _mm256_add_ps( + actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); + stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); + gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn); + gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), 
valueIg); + gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg); + prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), + _mm256_mul_ps(gradFg, checkF)); + prevStateGrad = + _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad); + checkIGrad = _mm256_mul_ps(gradIg, prevState); + checkFGrad = _mm256_mul_ps(gradFg, prevState); + checkOGrad = _mm256_mul_ps(gradOg, state); + } +#endif +#endif +}; + +} // namespace backward + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle + +#endif /* HL_LSTM_OPS_CUH_ */ diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc new file mode 100644 index 0000000000..77d317048a --- /dev/null +++ b/paddle/operators/math/lstm_compute.cc @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "LstmCompute.h" +#include "paddle/operators/math/detail/lstm_cpu_kernel.h" +#include "paddle/operators/math/detail/lstm_kernel.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct LstmUnitFunctor { + static void compute(lstm_value value, int frame_size, int batch_size, + std::string gate_act, std::string cell_act, + std::string cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_forward(detail::forward::lstm(), value, frameSize, + ActiveType(cand_act), ActiveType(gate_act), + ActiveType(cell_act)); + value.gateValue += frameSize * 4; + value.stateValue += frameSize; + value.stateActiveValue += frameSize; + value.outputValue += frameSize; + if (value.prevStateValue) { + value.prevStateValue += frameSize; + } + } + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(lstm_value value, lstm_grad grad, int frame_size, + int batch_size, std::string gate_act, + std::string cell_act, std::string cand_act) { + for (int b = 0; b < batchSize; b++) { + detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, + frameSize, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + + value.gateValue += frameSize * 4; + value.stateValue += frameSize; + value.stateActiveValue += frameSize; + value.outputValue += frameSize; + if (value.prevStateValue) { + value.prevStateValue += frameSize; + } + + grad.gateGrad += frameSize * 4; + grad.stateGrad += frameSize; + grad.stateActiveGrad += frameSize; + grad.outputGrad += frameSize; + if (grad.prevStateGrad) { + grad.prevStateGrad += frameSize; + } + } + }; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu new file mode 100644 index 0000000000..a7e23920aa --- /dev/null +++ b/paddle/operators/math/lstm_compute.cu @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "LstmCompute.h" +#include "paddle/operators/math/detail/lstm_cpu_kernel.h" +#include "paddle/operators/math/detail/lstm_kernel.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct LstmUnitFunctor { + static void compute(lstm_value value, int frame_size, int batch_size, + std::string gate_act, std::string cell_act, + std::string cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::gpu_lstm_forward(detail::forward::lstm(), value, frameSize, + ActiveType(cand_act), ActiveType(gate_act), + ActiveType(cell_act)); + value.gateValue += frameSize * 4; + value.stateValue += frameSize; + value.stateActiveValue += frameSize; + value.outputValue += frameSize; + if (value.prevStateValue) { + value.prevStateValue += frameSize; + } + } + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(lstm_value value, lstm_grad grad, int frame_size, + int batch_size, std::string gate_act, + std::string cell_act, std::string cand_act) { + for (int b = 0; b < batchSize; b++) { + detail::gpu_lstm_backward(detail::backward::lstm(), value, grad, + frameSize, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + + value.gateValue += frameSize * 4; + value.stateValue += frameSize; + value.stateActiveValue += frameSize; + value.outputValue += frameSize; + if (value.prevStateValue) { + value.prevStateValue += frameSize; + } + + grad.gateGrad += frameSize * 4; + grad.stateGrad += 
frameSize; + grad.stateActiveGrad += frameSize; + grad.outputGrad += frameSize; + if (grad.prevStateGrad) { + grad.prevStateGrad += frameSize; + } + } + }; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h new file mode 100644 index 0000000000..2d7fccf1a0 --- /dev/null +++ b/paddle/operators/math/lstm_compute.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/platform/macros.h" + +namespace paddle { +namespace operators { +namespace math { + +typedef enum { + HL_ACTIVATION_SIGMOID = 0, + HL_ACTIVATION_RELU = 1, + HL_ACTIVATION_TANH = 2, + HL_ACTIVATION_LINEAR = 3, + HL_ACTIVATION_END +} activation_mode_t; + +template +struct lstm_value { + real *gateValue; + real *prevStateValue; + real *stateValue; + real *stateActiveValue; + real *outputValue; + real *checkIg; + real *checkFg; + real *checkOg; +}; + +template +struct lstm_grad { + real *gateGrad; + real *prevStateGrad; + real *stateGrad; + real *stateActiveGrad; + real *outputGrad; + real *checkIgGrad; + real *checkFgGrad; + real *checkOgGrad; +}; + +activation_mode_t ActiveType(const std::string &type) { + if (type == "sigmoid") { + return HL_ACTIVATION_SIGMOID; + } else if (type == "relu") { + return HL_ACTIVATION_RELU; + } else if (type == "tanh") { + return HL_ACTIVATION_TANH; + } else if (type == "linear" || type == "") { + return HL_ACTIVATION_LINEAR; + } else { + PADDLE_THROW("Do not support activation type."); + } +} + +template +class LstmUnitFunctor { + public: + static void compute(lstm_value value, int frame_size, int batch_size, + std::string gate_act, std::string cell_act, + std::string cand_act); +}; + +template +class LstmUnitGradFunctor { + public: + static void compute(lstm_value value, lstm_grad grad, int frame_size, + int batch_size, std::string gate_act, + std::string cell_act, std::string cand_act); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index c29baaae08..f4da949d4e 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -18,6 +18,37 @@ namespace paddle { namespace operators { namespace math { +template +class CopyMatrixRowsFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& src, 
const size_t* index, + framework::Tensor& dst, bool is_src_index) { + auto src_dims = src.dims(); + auto dst_dims = dst.dims(); + PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2."); + PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst.data(); + for (int i = 0; i < height; ++i) { + if (is_src_index) { + memcpy(dst_data + i * width, src_data + index[i] * width, + width * sizeof(T)); + } else { + memcpy(dst_data + index[i] * width, src_data + i * width, + width * sizeof(T)); + } + } + } +}; + +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; + template class LoDTensor2BatchFunctor; template class Batch2LoDTensor2Functor; diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index 5afb87e4a4..ecd05a30d3 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -18,6 +18,53 @@ namespace paddle { namespace operators { namespace math { +template +__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const int* index, + int height, int width, + const bool is_src_index) { + int idx = threadIdx.x; + int idy = threadIdx.y; + int id = blockIdx.x + idy * GridDimX; + while (id < height) { + int src_idx = is_src_index ? index[id] : id; + int dst_idx = is_src_index ? 
id : index[id]; + T* src_data = src + src_idx * width; + T* dst_data = dst + dst_idx * width; + for (int i = idx; i < width; i += BlockDimX) { + dst_data[i] = src_data[i]; + } + id += BlockDimY * GridDimX; + } +} + +template +class CopyMatrixRowsFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& src, const size_t* index, + framework::Tensor& dst, bool is_src_index) { + auto src_dims = src.dims(); + auto dst_dims = dst.dims(); + PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2."); + PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst.data(); + + dim3 threads(128, 8); + dim3 grid(8, 1); + auto stream = reinterpret_cast(context); + CopyMatrixRowsKernel<<>>( + src_data, dst_data, index, height, width); + } +}; + +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; + template class LoDTensor2BatchFunctor; template class Batch2LoDTensor2Functor; diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 6ee870cf78..e662292a02 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -16,6 +16,19 @@ namespace paddle { namespace operators { namespace math { +template +class CopyMatrixRowsFunctor { + public: + // If is_src_index is true, + // copy the indexed rows of input src to the output dst. + // If is_src_index is false, + // copy the input src to the indexed rows of output dst. + // The indexed rows are based on the input index. 
+ void operator()(const platform::DeviceContext& context, + const framework::Tensor& src, const size_t* index, + framework::Tensor& dst, const bool is_src_index); +}; + template class LoDTensor2BatchFunctor { public: @@ -97,8 +110,11 @@ class LoDTensor2BatchFunctor { } batch_starts[n + 1] = batch_id; } + + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, batch, true); } -} +}; template class Batch2LoDTensor2Functor { @@ -107,6 +123,7 @@ class Batch2LoDTensor2Functor { const framework::LoDTensor& batch, framework::LoDTensor& lod_tensor, const bool is_reverse) const; +}; } // namespace math } // namespace operators From 2a8dbd130d46c949373d12aedcd0ca84f015a0be Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 17 Oct 2017 13:50:22 +0800 Subject: [PATCH 3/9] LSTM Operator forward implementation. --- paddle/framework/CMakeLists.txt | 4 +- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/lstm_op.cc | 41 +++-- paddle/operators/lstm_op.h | 108 +++++++++++-- paddle/operators/math/CMakeLists.txt | 5 +- .../math/detail/hl_activation_functions.h | 146 ++++++++++++++++-- .../operators/math/detail/hl_cpu_functions.cc | 44 ------ paddle/operators/math/detail/hl_functions.h | 95 ++++++++++-- .../operators/math/detail/hl_gpu_functions.h | 65 ++++---- .../operators/math/detail/lstm_cpu_kernel.h | 46 +++--- .../operators/math/detail/lstm_gpu_kernel.h | 74 +++++---- paddle/operators/math/detail/lstm_kernel.h | 29 ++-- paddle/operators/math/lstm_compute.cc | 52 ++++--- paddle/operators/math/lstm_compute.cu | 63 ++++---- paddle/operators/math/lstm_compute.h | 51 +++--- paddle/operators/math/sequence2batch.cc | 14 +- paddle/operators/math/sequence2batch.cu | 25 +-- paddle/operators/math/sequence2batch.h | 49 ++++-- .../paddle/v2/framework/tests/test_lstm_op.py | 116 ++++++++++++++ 19 files changed, 730 insertions(+), 301 deletions(-) delete mode 100644 paddle/operators/math/detail/hl_cpu_functions.cc create mode 100644 
python/paddle/v2/framework/tests/test_lstm_op.py diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c8d9dac21d..c993189603 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -46,9 +46,9 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op mean_op) if(WITH_GPU) - nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) + # nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) else() - cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) + # cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) endif() cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 75fcc1cda1..7ce774a285 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -115,7 +115,8 @@ set(DEPS_OPS softmax_with_cross_entropy_op sum_op pool_op - pool_with_index_op) + pool_with_index_op + lstm_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) +op_library(lstm_op DEPS sequence2batch) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 1803aa1e44..7a72a08c50 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -22,12 +22,12 @@ class LSTMOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + 
void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Output(Hidden) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("H"), + PADDLE_ENFORCE(ctx->HasOutput("Cell"), "Output(Cell) of LSTM should not be null."); auto x_dims = ctx->GetInputDim("Input"); @@ -60,7 +60,7 @@ class LSTMOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, "The first dimension of Input(Bias) should be 1."); - if (ctx->Attrs().Get("use_peepholes")) { + if (ctx->Attrs().Get("usePeepholes")) { PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, "The second dimension of Input(Bias) should be " "7 * %d if enable peepholes connection", @@ -73,7 +73,7 @@ class LSTMOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Hidden", x_dims); ctx->SetOutputDim("Cell", x_dims); - ctx->SetOutputDim("Hidden", x_dims); + ctx->SetOutputDim("Batch", x_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } @@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(LoDTensor) the first input is a LodTensor, which support " "variable-time length input sequence. The underlying tensor in " - "this LoDTenosr is a matrix with shape (T X D), where, T is the " + "this LoDTenosr is a matrix with shape (T X 4D), where, T is the " "total time steps in this mini-batch, D is the hidden size."); AddInput("H0", "(Tensor, optional) the initial hidden state is an optional " @@ -103,14 +103,21 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Bias", "(Tensor) the learnable weights, which contains two parts: " "input-hidden bias weight and peephole connections weight if " - "seting `use_peepholes` True. " - "1. `use_peepholes = False` " + "seting `usePeepholes` True. " + "1. 
`usePeepholes = False` " " - The shape is (1 x 4*D). " " - Bias = {b_i, b_f, b_c, b_o}." - "2. `use_peepholes = True` " + "2. `usePeepholes = True` " " - The shape is (1 x 7*D). " " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); - AddOutput("Batch", "(LoDTensor) save the reorganized input as batch info. ") + AddOutput("BatchGate", + "(LoDTensor) This LoDTensor contains input gate, forget gate " + "and output gate aftern the nonlinear computation. This " + "LoDTensor has the same shape with the reorganized input, which " + "was also be called batch input. The LoD size is 2. The first " + "LoD is the batch offsets and the second LoD contains the " + "indexes, which denote the position of reorganized sequence " + "in the raw input.") .AsIntermediate(); AddOutput("Hidden", "(LoDTensor) the hidden state lod tensor of LSTM operator. " @@ -118,25 +125,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Cell", "(LoDTensor) the cell state lod tensor of LSTM operator. " "The shape and lod is the same with the `Input`."); - AddAttr("use_peepholes", + AddAttr("usePeepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") .SetDefault(true); - AddAttr("is_reverse", + AddAttr("isReverse", "(bool, defalut: False) " "whether to compute reversed LSTM.") - .SetDefault(true); + .SetDefault(false); AddAttr( - "gate_activation", + "gateActivation", "(string, defalut: sigmoid)" "The activation for input gate, forget gate and output " "gate, `sigmoid` by defalut.") .SetDefault("sigmoid"); - AddAttr("cell_activation", + AddAttr("cellActivation", "(string, defalut: tanh)" "The activation for cell output, `tanh` by defalut.") .SetDefault("tanh"); - AddAttr("candidate_activation", + AddAttr("candidateActivation", "(string, defalut: tanh)" "The activation for candidate hidden state, " "`tanh` by defalut.") @@ -173,7 +180,7 @@ are the cell input and cell output activation functions, `tanh` is usually used for them. 
\f$\tilde{c_t}\f$ is also called candidate hidden state, which is computed based on the current input and the previous hidden state. -Set `use_peepholes` False to disable peephole connection [2]. The formula +Set `usePeepholes` False to disable peephole connection [2]. The formula is omitted here. @note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ @@ -196,7 +203,7 @@ class LSTMGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), "Input(Hidden@GRAD) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 037f0485a1..6924cba68f 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -14,30 +14,120 @@ limitations under the License. 
*/ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/sequence2batch.h" namespace paddle { namespace operators { using framework::LoDTensor; using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; template class LSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input_t = ctx.Input("Input"); - auto* batch_t = ctx.Input("Batch"); - auto* bias_t = ctx.Input("Bias"); - bool is_reverse = ctx.Attr("is_reverse"); - LoDTensor2BatchFunctor to_batch(ctx.device_context(), input_t, - batch_t, is_reverse); - - auto in_dims = input_t->dims(); + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* batch_gate = ctx.Output("BatchGate"); + batch_gate->mutable_data(ctx.GetPlace()); + auto* hidden_out = ctx.Output("Hidden"); + hidden_out->mutable_data(ctx.GetPlace()); + auto* cell_out = ctx.Output("Cell"); + cell_out->mutable_data(ctx.GetPlace()); + + // Now the function ShareLoD in InferShape is not implemented. + // So copy LoD here. 
+ ctx.ShareLoD("Input", "Hidden"); + ctx.ShareLoD("Input", "Cell"); + + bool is_reverse = ctx.Attr("isReverse"); + math::LoDTensor2BatchFunctor to_batch; + to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); + + auto in_dims = input->dims(); int frame_size = in_dims[1]; - if (bias_t) { + if (bias) { + Eigen::array extents({{1, 4 * frame_size}}); + Eigen::array offsets({{0, 0}}); auto b = EigenMatrix::From(*bias); + auto gate = EigenMatrix::From(*batch_gate); + gate.device(ctx.GetEigenDevice()) = + gate + + b.slice(offsets, extents) + .reshape(Eigen::array({{1, frame_size * 4}})) + .broadcast( + Eigen::array({{static_cast(in_dims[0]), 1}})); + } + + math::LstmMetaValue lstm_value; + T* bias_data = const_cast(bias->data()); + // the code styple in LstmMetaValue will be updated later. + lstm_value.checkIg = bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.prevStateValue = nullptr; + + framework::LoDTensor batch_out; + batch_out.mutable_data(in_dims, ctx.GetPlace()); + framework::LoDTensor batch_cell; + batch_cell.mutable_data(in_dims, ctx.GetPlace()); + framework::LoDTensor batch_cell_pre_act; + batch_cell_pre_act.mutable_data(in_dims, ctx.GetPlace()); + + auto batch_lod = batch_gate->lod()[0]; + int num_batch = batch_lod.size() - 1; + + auto gate_act = ctx.Attr("gateActivation"); + auto cell_act = ctx.Attr("cellActivation"); + auto cand_act = ctx.Attr("candidateActivation"); + + for (int n = 0; n < num_batch; n++) { + int bstart = batch_lod[n]; + int bend = batch_lod[n + 1]; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor out_t = batch_out.Slice(bstart, bend); + Tensor cell_t = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend); + + int cur_batch_size = bend - bstart; + + if (n != 0) { + int pre_end = batch_lod[n - 1]; + auto pre_hidden_t = batch_out.Slice(pre_end, bstart); + 
math::matmul(ctx.device_context(), pre_hidden_t, false, + *weight, false, static_cast(1.0), &gate_t, + static_cast(0.0)); + } + // else if : how to pass the state from + // last mini-batch will be supported later + + lstm_value.gateValue = gate_t.data(); + lstm_value.outputValue = out_t.data(); + lstm_value.stateValue = cell_t.data(); + lstm_value.stateActiveValue = cell_pre_act_t.data(); + math::LstmUnitFunctor::compute(ctx.device_context(), lstm_value, + frame_size, cur_batch_size, + gate_act, cell_act, cand_act); + lstm_value.prevStateValue = lstm_value.stateValue; } + + math::Batch2LoDTensorFunctor to_seq; + batch_out.set_lod(batch_gate->lod()); + // restore the output hidden in LoDTensor from the batch hidden + to_seq(ctx.device_context(), batch_out, *hidden_out); + + batch_out.set_lod(batch_gate->lod()); + // restore the output cell state in LoDTensor from the batch cell + to_seq(ctx.device_context(), batch_cell, *cell_out); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 1a2f623ce7..794ffc3997 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -5,13 +5,16 @@ if(WITH_GPU) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) + nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) + nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) - cc_library(vol2col SRCS vol2col.cc DEPS device_context) + 
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) + cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h index d5cf874636..9d7d9914f0 100644 --- a/paddle/operators/math/detail/hl_activation_functions.h +++ b/paddle/operators/math/detail/hl_activation_functions.h @@ -16,15 +16,30 @@ limitations under the License. */ #define HL_ACTIVATION_FUNCTIONS_H_ #include "hl_functions.h" +#include "paddle/operators/math/lstm_compute.h" /** * Active functions: sigmoid, relu, tanh and linear. */ -#define HPPL_ACTIVE_FUNCTION \ +#define FLOAT_ACTIVE_FUNCTION \ + { \ + hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \ + hppl::typef::linear \ + } + +#define DOUBLE_ACTIVE_FUNCTION \ + { \ + hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \ + hppl::typed::linear \ + } + +#define AVX_ACTIVE_FUNCTION \ { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear } namespace hppl { +using activation_mode_t = paddle::operators::math::activation_mode_t; + /** * Hppl supports sigmoid, relu, tanh, linear active functions * for neural networks' forward and backward activation. 
@@ -36,25 +51,134 @@ class Active { typedef T (*backward)(T, T); }; +template +struct ForwardActType; + +template <> +struct ForwardActType { + using type = Active::forward; +}; + +template <> +struct ForwardActType { + using type = Active::forward; +}; + +template +struct BackwardActType; + +template <> +struct BackwardActType { + using type = Active::backward; +}; + +template <> +struct BackwardActType { + using type = Active::backward; +}; + #ifdef __NVCC__ namespace gpu { -static __device__ Active::forward forward[] = HPPL_ACTIVE_FUNCTION; -static __device__ Active::backward backward[] = HPPL_ACTIVE_FUNCTION; -static __device__ Active::forward forward[] = HPPL_ACTIVE_FUNCTION; -static __device__ Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +static __device__ Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; +static __device__ Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; + +static __device__ Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; +static __device__ Active::backward backward_d[] = + DOUBLE_ACTIVE_FUNCTION; + +template +struct ForwardAct { + __device__ typename ForwardActType::type operator()( + activation_mode_t type); +}; + +template <> +struct ForwardAct { + __device__ ForwardActType::type operator()(activation_mode_t type) { + return forward[type]; + } +}; + +template <> +struct ForwardAct { + __device__ ForwardActType::type operator()(activation_mode_t type) { + return forward_d[type]; + } +}; + +template +struct BackwardAct { + __device__ typename BackwardActType::type operator()( + activation_mode_t type); +}; + +template <> +struct BackwardAct { + __device__ BackwardActType::type operator()(activation_mode_t type) { + return backward[type]; + } +}; + +template <> +struct BackwardAct { + __device__ BackwardActType::type operator()(activation_mode_t type) { + return backward_d[type]; + } +}; + } // namespace gpu #else namespace cpu { -static Active::forward forward[] = HPPL_ACTIVE_FUNCTION; -static Active::backward backward[] = 
HPPL_ACTIVE_FUNCTION; -static Active::forward forward[] = HPPL_ACTIVE_FUNCTION; -static Active::backward backward[] = HPPL_ACTIVE_FUNCTION; +static Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; +static Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; + +static Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; +static Active::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION; + +template +struct ForwardAct { + typename ForwardActType::type operator()(activation_mode_t type); +}; + +template <> +struct ForwardAct { + ForwardActType::type operator()(activation_mode_t type) { + return forward[type]; + } +}; + +template <> +struct ForwardAct { + ForwardActType::type operator()(activation_mode_t type) { + return forward_d[type]; + } +}; + +template +struct BackwardAct { + typename BackwardActType::type operator()(activation_mode_t type); +}; + +template <> +struct BackwardAct { + BackwardActType::type operator()(activation_mode_t type) { + return backward[type]; + } +}; + +template <> +struct BackwardAct { + BackwardActType::type operator()(activation_mode_t type) { + return backward_d[type]; + } +}; + } // namespace cpu #ifdef __AVX__ namespace avx { -static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION; -static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION; +static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION; +static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION; } // namespace avx #endif #endif diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc deleted file mode 100644 index b42e11fd90..0000000000 --- a/paddle/operators/math/detail/hl_cpu_functions.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "/paddle/operators/math/detail/hl_functions.h" - -namespace hppl { - -real relu(const real a) { return a > 0.0f ? a : 0.0f; } - -real sigmoid(const real a) { - const real min = SIGMOID_THRESHOLD_MIN; - const real max = SIGMOID_THRESHOLD_MAX; - real tmp = (a < min) ? min : ((a > max) ? max : a); - return 1.0 / (1.0 + exp(-tmp)); -} - -real tanh(const real a) { - real tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -real linear(const real a) { return a; } - -real relu(const real a, const real b) { return a * (b > 0.0f ? 1.0f : 0.0f); } - -real sigmoid(const real a, const real b) { return a * b * (1 - b); } - -real tanh(const real a, const real b) { return a * (1.0f - b * b); } - -real linear(const real a, const real b) { return a; } -} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h index 4eda1adfe9..c77c119dfe 100644 --- a/paddle/operators/math/detail/hl_functions.h +++ b/paddle/operators/math/detail/hl_functions.h @@ -25,31 +25,94 @@ limitations under the License. */ */ #define SIGMOID_THRESHOLD_MAX 13.0 +/** + * The maximum input value for exp, used to avoid overflow problem. + * currently only used for tanh function. + */ +#define EXP_MAX_INPUT 40.0 + #ifndef __NVCC__ namespace hppl { +namespace typef { +/* + * forward activation + */ +float relu(const float a) { + return a > static_cast(0.0) ? 
a : static_cast(0.0); +} + +float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +float tanh(const float a) { + float tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +float linear(const float a) { return a; } + +/* + * backward activation + */ +float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } + +float sigmoid(const float a, const float b) { + return a * b * (static_cast(1) - b); +} + +float tanh(const float a, const float b) { + return a * (static_cast(1) - b * b); +} + +float linear(const float a, const float b) { return a; } +} // namespace typef + +namespace typed { /* * forward activation */ -template -T relu(const T a); -template -T sigmoid(const T a); -template -T tanh(const T a); -template -T linear(const T a); +double relu(const double a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +double linear(const double a) { return a; } /* * backward activation */ -template -T relu(const T a, const T b); -template -T sigmoid(const T a, const T b); -template -T tanh(const T a, const T b); -template -T linear(const T a, const T b); +double relu(const double a, const double b) { + return a * (b > 0.0 ? 
1.0 : 0.0); +} + +double sigmoid(const double a, const double b) { + return a * b * (static_cast(1) - b); +} + +double tanh(const double a, const double b) { + return a * (static_cast(1) - b * b); +} + +double linear(const double a, const double b) { return a; } +} // namespace typed + } // namespace hppl #ifdef __AVX__ diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h index 25fa7c409a..eee93dd578 100644 --- a/paddle/operators/math/detail/hl_gpu_functions.h +++ b/paddle/operators/math/detail/hl_gpu_functions.h @@ -18,13 +18,10 @@ limitations under the License. */ #include "hl_base.h" namespace hppl { +namespace typef { -template -__device__ static T relu(const T a) { - return a > 0.0f ? a : 0.0f; -} +__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; } -template <> __device__ static float sigmoid(const float a) { const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -32,7 +29,32 @@ __device__ static float sigmoid(const float a) { return __fdividef(1.0f, 1.0f + __expf(-tmp)); } -template <> +__device__ static float tanh(const float a) { + return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f; +} + +__device__ static float linear(const float a) { return a; } + +__device__ static float relu(const float a, const float b) { + return a * (b > 0.0f ? 1.0f : 0.0f); +} + +__device__ static float sigmoid(const float a, const float b) { + return a * b * (1.0f - b); +} + +__device__ static float tanh(const float a, const float b) { + return a * (1.0f - b * b); +} + +__device__ static float linear(const float a, const float b) { return a; } + +} // namespace typef + +namespace typed { + +__device__ static double relu(const double a) { return a > 0.0 ? 
a : 0.0; } + __device__ static double sigmoid(const double a) { const double min = SIGMOID_THRESHOLD_MIN; const double max = SIGMOID_THRESHOLD_MAX; @@ -40,40 +62,27 @@ __device__ static double sigmoid(const double a) { return 1.0 / (1.0 + exp(-tmp)); } -template <> -__device__ static float tanh(const float a) { - return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f; -} - -template <> __device__ static double tanh(const double a) { return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0; } -template -__device__ static T linear(const T a) { - return a; -} +__device__ static double linear(const double a) { return a; } -template -__device__ static T relu(const T a, const T b) { - return a * (b > 0.0f ? 1.0f : 0.0f); +__device__ static double relu(const double a, const double b) { + return a * (b > 0.0 ? 1.0 : 0.0); } -template -__device__ static T sigmoid(const T a, const T b) { +__device__ static double sigmoid(const double a, const double b) { return a * b * (1 - b); } -template -__device__ static T tanh(const T a, const T b) { - return a * (1.0f - b * b); +__device__ static double tanh(const double a, const double b) { + return a * (1.0 - b * b); } -template -__device__ static T linear(const T a, const T b) { - return a; -} +__device__ static double linear(const double a, const double b) { return a; } + +} // namespace typef } // namespace hppl diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index a8e78a449d..74d51d7bc9 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/lstm_compute.h" namespace paddle { @@ -23,7 +25,8 @@ namespace detail { #ifndef __NVCC__ template -void naive_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, +void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, + int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { @@ -57,9 +60,10 @@ void naive_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, rPrevState = value.prevStateValue[i]; } + hppl::cpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, hppl::cpu::forward[active_node], - hppl::cpu::forward[active_gate], hppl::cpu::forward[active_state]); + rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), + act(active_state)); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -72,8 +76,8 @@ void naive_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, } template -void naive_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, - int frameSize, +void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { @@ -123,11 +127,11 @@ void naive_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, rPrevState = value.prevStateValue[i]; } + hppl::cpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, hppl::cpu::backward[active_node], - hppl::cpu::backward[active_gate], hppl::cpu::backward[active_state]); + rCheckOGrad, act(active_node), act(active_gate), act(active_state)); gradIn[i] = rGradIn; gradIg[i] = 
rGradIg; @@ -144,8 +148,8 @@ void naive_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, } } -template -void avx_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, +template +void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { @@ -195,9 +199,9 @@ void avx_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize, #endif } -template -void avx_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, - int frameSize, +template +void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { @@ -271,13 +275,13 @@ void avx_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad, } template -void cpu_lstm_forward(Op op, lstm_value value, int frameSize, +void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { - avx_lstm_forward_one_sequence(op, value, frameSize, active_node, - active_gate, active_state); + if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); } else { naive_lstm_forward_one_sequence(op, value, frameSize, active_node, active_gate, active_state); @@ -285,13 +289,13 @@ void cpu_lstm_forward(Op op, lstm_value value, int frameSize, } template -void cpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize, - activation_mode_t active_node, +void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, + int frameSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && 
(sizeof(T) == 4)) { - avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); } else { naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, active_gate, active_state); diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 8d0274c19d..01310a49f8 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/operators/math/detail/lstm_kernel.h" +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/lstm_compute.h" #include "paddle/platform/cuda_helper.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace operators { @@ -27,10 +29,11 @@ namespace detail { * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmForward(Op op, lstm_value value, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { +__global__ void KeLstmForward( + Op op, LstmMetaValue value, int frameSize, int batchSize, + typename hppl::ForwardActType::type active_node, + typename hppl::ForwardActType::type active_gate, + typename hppl::ForwardActType::type active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -67,8 +70,7 @@ __global__ void KeLstmForward(Op op, lstm_value value, int frameSize, } op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, hppl::gpu::forward[active_node], - hppl::gpu::forward[active_gate], 
hppl::gpu::forward[active_state]); + rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -85,11 +87,11 @@ __global__ void KeLstmForward(Op op, lstm_value value, int frameSize, * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad, - int frameSize, int batchSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { +__global__ void KeLstmBackward( + Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, + int batchSize, typename hppl::BackwardActType::type active_node, + typename hppl::BackwardActType::type active_gate, + typename hppl::BackwardActType::type active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -143,8 +145,7 @@ __global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - hppl::gpu::backward[active_node], hppl::gpu::backward[active_gate], - hppl::gpu::backward[active_state]); + active_node, active_gate, active_state); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -177,7 +178,8 @@ __global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad, } template -void gpu_lstm_forward(Op op, lstm_value value, int frameSize, int batchSize, +void gpu_lstm_forward(const platform::DeviceContext& context, Op op, + LstmMetaValue value, int frameSize, int batchSize, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { @@ -194,22 +196,30 @@ void gpu_lstm_forward(Op op, lstm_value value, int frameSize, int batchSize, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 
32 - 1) / 32); } + using type = typename hppl::ForwardActType::type; + hppl::gpu::ForwardAct act; + type act_node = act(active_node); + type act_gate = act(active_gate); + type act_state = act(active_state); + + auto stream = + reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + /* isBatch= */ false><<>>( + op, value, frameSize, batchSize, act_node, act_gate, act_state); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + /* isBatch= */ true><<>>( + op, value, frameSize, batchSize, act_node, act_gate, act_state); } } template -void gpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize, - int batchSize, activation_mode_t active_node, +void gpu_lstm_backward(const platform::DeviceContext& context, Op op, + LstmMetaValue value, LstmMetaGrad grad, + int frameSize, int batchSize, + activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { dim3 threads; @@ -225,16 +235,22 @@ void gpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } + using type = typename hppl::BackwardActType::type; + hppl::gpu::BackwardAct act; + type act_node = act(active_node); + type act_gate = act(active_gate); + type act_state = act(active_state); + + auto stream = + reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + /* isBatch= */ false><<>>( + op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + /* isBatch= */ true><<>>( + op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h 
b/paddle/operators/math/detail/lstm_kernel.h index 107030f8ba..b1e59a4ee8 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "hl_activation_functions.h" +#include "paddle/operators/math/detail/hl_activation_functions.h" #ifdef __CUDA_ARCH__ #define INLINE __device__ inline @@ -33,9 +33,9 @@ class lstm { INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, T &prevState, T &state, T &stateAtv, T &output, T &checkI, T &checkF, T &checkO, - Active::forward actInput, - Active::forward actGate, - Active::forward actState) { + typename hppl::ForwardActType::type actInput, + typename hppl::ForwardActType::type actGate, + typename hppl::ForwardActType::type actState) { valueIn = actInput(valueIn); valueIg = actGate(valueIg + prevState * checkI); valueFg = actGate(valueFg + prevState * checkF); @@ -53,9 +53,9 @@ class lstm { __m256 &valueOg, __m256 &prevState, __m256 &state, __m256 &stateAtv, __m256 &output, __m256 &checkI, __m256 &checkF, __m256 &checkO, - Active<__m256>::forward actInput, - Active<__m256>::forward actGate, - Active<__m256>::forward actState) { + hppl::Active<__m256>::forward actInput, + hppl::Active<__m256>::forward actGate, + hppl::Active<__m256>::forward actState) { valueIn = actInput(valueIn); valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); @@ -81,9 +81,9 @@ class lstm { T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, - Active::backward actInput, - Active::backward actGate, - Active::backward actState) { + typename hppl::BackwardActType::type actInput, + typename 
hppl::BackwardActType::type actGate, + typename hppl::BackwardActType::type actState) { gradOg = actGate(outputGrad * stateAtv, valueOg); stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradIn = actInput(stateGrad * valueIg, valueIn); @@ -106,9 +106,10 @@ class lstm { __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, - __m256 &checkOGrad, Active<__m256>::backward actInput, - Active<__m256>::backward actGate, - Active<__m256>::backward actState) { + __m256 &checkOGrad, + hppl::Active<__m256>::backward actInput, + hppl::Active<__m256>::backward actGate, + hppl::Active<__m256>::backward actState) { gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg); stateGrad = _mm256_add_ps( actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); @@ -134,5 +135,3 @@ class lstm { } // namespace math } // namespace operators } // namespace paddle - -#endif /* HL_LSTM_OPS_CUH_ */ diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index 77d317048a..293c9da3a0 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "LstmCompute.h" +#include "paddle/operators/math/lstm_compute.h" #include "paddle/operators/math/detail/lstm_cpu_kernel.h" #include "paddle/operators/math/detail/lstm_kernel.h" @@ -22,19 +22,20 @@ namespace math { template struct LstmUnitFunctor { - static void compute(lstm_value value, int frame_size, int batch_size, + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { for (int b = 0; b < batch_size; b++) { - detail::cpu_lstm_forward(detail::forward::lstm(), value, frameSize, + detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frameSize * 4; - value.stateValue += frameSize; - value.stateActiveValue += frameSize; - value.outputValue += frameSize; + value.gateValue += frame_size * 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; if (value.prevStateValue) { - value.prevStateValue += frameSize; + value.prevStateValue += frame_size; } } } @@ -42,31 +43,36 @@ struct LstmUnitFunctor { template struct LstmUnitGradFunctor { - static void compute(lstm_value value, lstm_grad grad, int frame_size, - int batch_size, std::string gate_act, + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batchSize; b++) { + for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, - frameSize, ActiveType(cand_act), + frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frameSize * 4; - value.stateValue += frameSize; - value.stateActiveValue += frameSize; - value.outputValue += frameSize; + value.gateValue += frame_size 
* 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; if (value.prevStateValue) { - value.prevStateValue += frameSize; + value.prevStateValue += frame_size; } - grad.gateGrad += frameSize * 4; - grad.stateGrad += frameSize; - grad.stateActiveGrad += frameSize; - grad.outputGrad += frameSize; + grad.gateGrad += frame_size * 4; + grad.stateGrad += frame_size; + grad.stateActiveGrad += frame_size; + grad.outputGrad += frame_size; if (grad.prevStateGrad) { - grad.prevStateGrad += frameSize; + grad.prevStateGrad += frame_size; } } - }; + } +}; + +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu index a7e23920aa..aade604b9e 100644 --- a/paddle/operators/math/lstm_compute.cu +++ b/paddle/operators/math/lstm_compute.cu @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "LstmCompute.h" -#include "paddle/operators/math/detail/lstm_cpu_kernel.h" +#include "paddle/operators/math/detail/lstm_gpu_kernel.h" #include "paddle/operators/math/detail/lstm_kernel.h" +#include "paddle/operators/math/lstm_compute.h" namespace paddle { namespace operators { @@ -22,19 +22,20 @@ namespace math { template struct LstmUnitFunctor { - static void compute(lstm_value value, int frame_size, int batch_size, + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { for (int b = 0; b < batch_size; b++) { - detail::gpu_lstm_forward(detail::forward::lstm(), value, frameSize, - ActiveType(cand_act), ActiveType(gate_act), - ActiveType(cell_act)); - value.gateValue += frameSize * 4; - value.stateValue += frameSize; - value.stateActiveValue += frameSize; - value.outputValue += frameSize; + detail::gpu_lstm_forward(context, detail::forward::lstm(), value, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + value.gateValue += frame_size * 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; if (value.prevStateValue) { - value.prevStateValue += frameSize; + value.prevStateValue += frame_size; } } } @@ -42,31 +43,37 @@ struct LstmUnitFunctor { template struct LstmUnitGradFunctor { - static void compute(lstm_value value, lstm_grad grad, int frame_size, - int batch_size, std::string gate_act, + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batchSize; b++) { - detail::gpu_lstm_backward(detail::backward::lstm(), value, grad, - frameSize, ActiveType(cand_act), - ActiveType(gate_act), ActiveType(cell_act)); + for (int b = 0; b < batch_size; b++) { + 
detail::gpu_lstm_backward(context, detail::backward::lstm(), value, + grad, frame_size, batch_size, + ActiveType(cand_act), ActiveType(gate_act), + ActiveType(cell_act)); - value.gateValue += frameSize * 4; - value.stateValue += frameSize; - value.stateActiveValue += frameSize; - value.outputValue += frameSize; + value.gateValue += frame_size * 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; if (value.prevStateValue) { - value.prevStateValue += frameSize; + value.prevStateValue += frame_size; } - grad.gateGrad += frameSize * 4; - grad.stateGrad += frameSize; - grad.stateActiveGrad += frameSize; - grad.outputGrad += frameSize; + grad.gateGrad += frame_size * 4; + grad.stateGrad += frame_size; + grad.stateActiveGrad += frame_size; + grad.outputGrad += frame_size; if (grad.prevStateGrad) { - grad.prevStateGrad += frameSize; + grad.prevStateGrad += frame_size; } } - }; + } +}; + +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 2d7fccf1a0..ebf765c02e 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -14,7 +14,8 @@ limitations under the License. 
*/ #pragma once -#include "paddle/platform/macros.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace operators { @@ -28,28 +29,28 @@ typedef enum { HL_ACTIVATION_END } activation_mode_t; -template -struct lstm_value { - real *gateValue; - real *prevStateValue; - real *stateValue; - real *stateActiveValue; - real *outputValue; - real *checkIg; - real *checkFg; - real *checkOg; +template +struct LstmMetaValue { + T *gateValue; + T *prevStateValue; + T *stateValue; + T *stateActiveValue; + T *outputValue; + T *checkIg; + T *checkFg; + T *checkOg; }; -template -struct lstm_grad { - real *gateGrad; - real *prevStateGrad; - real *stateGrad; - real *stateActiveGrad; - real *outputGrad; - real *checkIgGrad; - real *checkFgGrad; - real *checkOgGrad; +template +struct LstmMetaGrad { + T *gateGrad; + T *prevStateGrad; + T *stateGrad; + T *stateActiveGrad; + T *outputGrad; + T *checkIgGrad; + T *checkFgGrad; + T *checkOgGrad; }; activation_mode_t ActiveType(const std::string &type) { @@ -69,7 +70,8 @@ activation_mode_t ActiveType(const std::string &type) { template class LstmUnitFunctor { public: - static void compute(lstm_value value, int frame_size, int batch_size, + static void compute(const platform::DeviceContext &context, + LstmMetaValue value, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act); }; @@ -77,8 +79,9 @@ class LstmUnitFunctor { template class LstmUnitGradFunctor { public: - static void compute(lstm_value value, lstm_grad grad, int frame_size, - int batch_size, std::string gate_act, + static void compute(const platform::DeviceContext &context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act); }; diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index f4da949d4e..10c6e105b9 100644 --- 
a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -22,12 +22,14 @@ template class CopyMatrixRowsFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& src, const size_t* index, - framework::Tensor& dst, bool is_src_index) { + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, bool is_src_index) { auto src_dims = src.dims(); auto dst_dims = dst.dims(); - PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2."); - PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, + "The src must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, + "The dst must be matrix with rank 2."); PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], "The width of src and dst must be same."); auto height = dst_dims[0]; @@ -50,7 +52,9 @@ template class CopyMatrixRowsFunctor; template class CopyMatrixRowsFunctor; template class LoDTensor2BatchFunctor; -template class Batch2LoDTensor2Functor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index ecd05a30d3..e478c46db7 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -19,8 +19,8 @@ namespace operators { namespace math { template -__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const int* index, - int height, int width, +__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, + int64_t height, int64_t width, const bool is_src_index) { int idx = threadIdx.x; int idy = threadIdx.y; @@ -28,7 +28,7 @@ __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const int* index, while (id < height) { int src_idx = is_src_index ? 
index[id] : id; int dst_idx = is_src_index ? id : index[id]; - T* src_data = src + src_idx * width; + const T* src_data = src + src_idx * width; T* dst_data = dst + dst_idx * width; for (int i = idx; i < width; i += BlockDimX) { dst_data[i] = src_data[i]; @@ -41,12 +41,14 @@ template class CopyMatrixRowsFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& src, const size_t* index, - framework::Tensor& dst, bool is_src_index) { + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, bool is_src_index) { auto src_dims = src.dims(); auto dst_dims = dst.dims(); - PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2."); - PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims.size(), 2, + "The src must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(dst_dims.size(), 2, + "The dst must be matrix with rank 2."); PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], "The width of src and dst must be same."); auto height = dst_dims[0]; @@ -56,9 +58,10 @@ class CopyMatrixRowsFunctor { dim3 threads(128, 8); dim3 grid(8, 1); - auto stream = reinterpret_cast(context); + auto stream = + reinterpret_cast(context).stream(); CopyMatrixRowsKernel<<>>( - src_data, dst_data, index, height, width); + src_data, dst_data, index, height, width, is_src_index); } }; @@ -66,7 +69,9 @@ template class CopyMatrixRowsFunctor; template class CopyMatrixRowsFunctor; template class LoDTensor2BatchFunctor; -template class Batch2LoDTensor2Functor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index e662292a02..3813d71238 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -12,6 +12,11 @@ WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + namespace paddle { namespace operators { namespace math { @@ -25,8 +30,8 @@ class CopyMatrixRowsFunctor { // copy the input src to the indexed rows of output dst. // The indexed rows are based on the input index. void operator()(const platform::DeviceContext& context, - const framework::Tensor& src, const size_t* index, - framework::Tensor& dst, const bool is_src_index); + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, const bool is_src_index); }; template @@ -35,8 +40,8 @@ class LoDTensor2BatchFunctor { void operator()(const platform::DeviceContext& context, const framework::LoDTensor& lod_tensor, framework::LoDTensor& batch, const bool is_reverse) const { - auto lods = lod_tensor->lod(); - PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + auto lods = lod_tensor.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); auto lod = lods[0]; // Calculate the length of each sequence and @@ -47,7 +52,7 @@ class LoDTensor2BatchFunctor { // struct SeqInfo { SeqInfo(int start, int length, int seq_idx) - : start(start), length(length), seqIdx(seq_idx) {} + : start(start), length(length), seq_idx(seq_idx) {} int start; int length; int seq_idx; @@ -78,19 +83,19 @@ class LoDTensor2BatchFunctor { // The batch number represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. 
- auto batch_lods = batch->lod(); - if (!batch_lods) { - batch_lods->resize(2); + auto batch_lods = batch.lod(); + if (batch_lods.size() == 0) { + batch_lods.resize(2); } // batch_lods[0] is the start positions for batch LoDTensor int num_batch = (size_t)seq_info[0].length; - batch_lods[0]->resize(num_batch + 1); + batch_lods[0].resize(num_batch + 1); // batch_lods[1] is the raw index in the input LoDTensor - auto dims = lod_tensor->dims(); - batch_lods[1]->resize(dims[0]); + auto dims = lod_tensor.dims(); + batch_lods[1].resize(dims[0]); - auto* batch_starts = batch_lods[0].data(); - auto* seq2batch_idx = batch_lods[1].data(); + size_t* batch_starts = batch_lods[0].data(); + size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; for (size_t n = 0; n < num_batch; n++) { int batch_id = batch_starts[n]; @@ -112,17 +117,27 @@ class LoDTensor2BatchFunctor { } CopyMatrixRowsFunctor to_batch; - to_batch(context, lod_tensor, batch, true); + to_batch(context, lod_tensor, seq2batch_idx, batch, true); } }; template -class Batch2LoDTensor2Functor { +class Batch2LoDTensorFunctor { public: void operator()(const platform::DeviceContext& context, const framework::LoDTensor& batch, - framework::LoDTensor& lod_tensor, - const bool is_reverse) const; + framework::LoDTensor& lod_tensor) const { + auto in_lod = batch.lod(); + PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, + "The LoD size of input `batch` should be 2."); + auto out_lod = lod_tensor.lod(); + PADDLE_ENFORCE_EQ(out_lod[0][0], out_lod[1].size()); + PADDLE_ENFORCE_EQ(out_lod[0][0], lod_tensor.dims()[0]); + PADDLE_ENFORCE_EQ(out_lod[0][0], batch.dims()[0]); + CopyMatrixRowsFunctor to_seq; + size_t* index = out_lod[1].data(); + to_seq(context, batch, index, lod_tensor, false); + } }; } // namespace math diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py new file mode 100644 index 0000000000..f3f4c84b2a --- /dev/null +++ 
b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -0,0 +1,116 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def identity(x): + return x + + +def sigmoid(x): + return 1. / (1. + np.exp(-x)) + + +def tanh(x): + return 2. * sigmoid(2. * x) - 1. + + +def relu(x): + return np.maximum(x, 0) + + +def lstm( + input, # T x 4D + lod, # 1 x N + h0=None, # N x D + c0=None, # N x D + w_h=None, # D x 4D + w_b=None, # 1 x 4D + w_c=None, # 1 x 3D + is_reverse=False, + gate_act=None, + cell_act=None, + cand_act=None): + def _step(x, w_h, w_c, h_pre, c_pre, gate_act, cell_act, cand_act): + g = np.dot(h_pre, w_h) # 1 x 4D + g = g + x + g = np.reshape(g, (1, g.size)) + c, g_i, g_f, g_o = np.split(g, 4, axis=1) + if w_c is None: + g_i = gate_act(g_i) # 1 x D + g_f = gate_act(g_f) # 1 x D + else: + w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) + g_i = gate_act(g_i + w_ic * c_pre) # 1 x D + g_f = gate_act(g_f + w_fc * c_pre) # 1 x D + c = g_f * c_pre + g_i * cand_act(c) # 1 x D + + if w_c is None: + g_o = gate_act(g_o) # 1 x D + else: + _, _, w_oc = np.split(w_c, 3, axis=1) + g_o = gate_act(g_o + w_oc * c) # 1 x D + h = g_o * cell_act(c) + return h, c + + offset = lod[0] + batch_size = len(offset) - 1 + hidden = [] + cell = [] + if w_b is not None: + input = input + np.tile(w_b, (offset[-1], 1)) + for i in range(batch_size): + # compute one sequence + seq_len = offset[i + 1] - offset[i] + x = input[offset[i]:offset[i + 1], :] + h_pre = h0[i] # 1 x D + c_pre = h0[i] # 1 x D + for j in range(seq_len): + # compute one step + h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act, + cell_act, cand_act) + hidden.append(h_pre.flatten()) + cell.append(c_pre.flatten()) + + hidden = np.array(hidden).astype("float64") + cell = np.array(cell).astype("float64") + assert hidden.shape == (input.shape[0], input.shape[1] / 4) + assert cell.shape == (input.shape[0], input.shape[1] / 4) + return hidden, cell + + +class LstmUnitTest(OpTest): + def set_data(self): + lod = 
[[0, 2, 6, 9]] + shape = (9, 64) + + x = np.random.normal(size=(9, 4 * 64)).astype("float64") + h0 = np.random.normal(size=(4, 64)).astype("float64") + c0 = np.random.normal(size=(4, 64)).astype("float64") + w = np.random.normal(size=(64, 4 * 64)).astype("float64") + b = np.random.normal(size=(1, 7 * 64)).astype("float64") + + w_b = b[:, 4 * 64] + w_c = b[:, 4 * 64:] + h, c = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) + + self.inputs = {'Input': x, 'H0': h0, 'C0': c0, 'Weight': w, 'Bias': b} + self.inputs = {'Hidden': h, 'Cell': c} + self.attrs = { + 'usePeepholes': True, + 'isReverse': False, + 'gateActivation': 'sigmoid', + 'cellActivation': 'tanh', + 'candidateActivation': 'tanh' + } + + def setUp(self): + self.set_data() + self.op_type = "lstm" + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From d1fbf50b9ebab35ea84b33fd330ef8c1b4e79bd3 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 19 Oct 2017 11:28:46 +0800 Subject: [PATCH 4/9] Add unit testing for forwad implementation. 
--- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/lstm_op.cc | 8 +- paddle/operators/lstm_op.h | 19 ++-- paddle/operators/math/CMakeLists.txt | 6 +- .../operators/math/detail/hl_avx_functions.cc | 4 +- .../operators/math/detail/hl_cpu_functions.cc | 89 +++++++++++++++++++ paddle/operators/math/detail/hl_functions.h | 89 ++++--------------- .../operators/math/detail/lstm_gpu_kernel.h | 50 +++++------ paddle/operators/math/lstm_compute.cc | 2 + paddle/operators/math/lstm_compute.cu | 42 ++------- paddle/operators/math/lstm_compute.h | 2 +- paddle/operators/math/sequence2batch.h | 23 ++--- .../paddle/v2/framework/tests/test_lstm_op.py | 83 +++++++++++------ 13 files changed, 233 insertions(+), 186 deletions(-) create mode 100644 paddle/operators/math/detail/hl_cpu_functions.cc diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7ce774a285..0c53ed3cdc 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) -op_library(lstm_op DEPS sequence2batch) +op_library(lstm_op DEPS sequence2batch lstm_compute math_function) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 7a72a08c50..f360502e66 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -44,7 +44,7 @@ class LSTMOp : public framework::OperatorWithKernel { "should be the same."); } - int frame_size = x_dims[1]; + int frame_size = x_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2."); @@ -71,9 +71,9 @@ class LSTMOp : public framework::OperatorWithKernel { "4 * %d if diable peepholes connection", frame_size); } - ctx->SetOutputDim("Hidden", x_dims); - 
ctx->SetOutputDim("Cell", x_dims); - ctx->SetOutputDim("Batch", x_dims); + ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); + ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); + ctx->SetOutputDim("BatchGate", x_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 6924cba68f..affa44c6fb 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -52,9 +52,14 @@ class LSTMKernel : public framework::OpKernel { to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); auto in_dims = input->dims(); - int frame_size = in_dims[1]; + int frame_size = in_dims[1] / 4; + framework::DDim dims({in_dims[0], frame_size}); if (bias) { + // framework::Tensor cpu_t; + // cpu_t.mutable_data(in_dims, platform::CPUPlace()); + // cpu_t.CopyFrom(*batch_gate, platform::CPUPlace(), + // ctx.device_context()); Eigen::array extents({{1, 4 * frame_size}}); Eigen::array offsets({{0, 0}}); auto b = EigenMatrix::From(*bias); @@ -76,15 +81,14 @@ class LSTMKernel : public framework::OpKernel { lstm_value.prevStateValue = nullptr; framework::LoDTensor batch_out; - batch_out.mutable_data(in_dims, ctx.GetPlace()); + batch_out.mutable_data(dims, ctx.GetPlace()); framework::LoDTensor batch_cell; - batch_cell.mutable_data(in_dims, ctx.GetPlace()); + batch_cell.mutable_data(dims, ctx.GetPlace()); framework::LoDTensor batch_cell_pre_act; - batch_cell_pre_act.mutable_data(in_dims, ctx.GetPlace()); + batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); auto batch_lod = batch_gate->lod()[0]; int num_batch = batch_lod.size() - 1; - auto gate_act = ctx.Attr("gateActivation"); auto cell_act = ctx.Attr("cellActivation"); auto cand_act = ctx.Attr("candidateActivation"); @@ -125,9 +129,12 @@ class LSTMKernel : public framework::OpKernel { // restore the output hidden in LoDTensor from the batch hidden to_seq(ctx.device_context(), batch_out, *hidden_out); - 
batch_out.set_lod(batch_gate->lod()); + batch_cell.set_lod(batch_gate->lod()); // restore the output cell state in LoDTensor from the batch cell to_seq(ctx.device_context(), batch_cell, *cell_out); + + auto t = framework::EigenVector::Flatten(*batch_gate); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 794ffc3997..2771b5de40 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(detail) + if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -6,7 +8,7 @@ if(WITH_GPU) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) - nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context) + nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -14,7 +16,7 @@ else() cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) - cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context) + cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc index 
70e7d80304..415bac5d93 100644 --- a/paddle/operators/math/detail/hl_avx_functions.cc +++ b/paddle/operators/math/detail/hl_avx_functions.cc @@ -14,10 +14,12 @@ limitations under the License. */ #include #include "hl_functions.h" +// TODO(qingqing) refine this dependence +#include "paddle/cuda/src/avx_mathfun.h" namespace hppl { -extern __m256 exp(__m256 a); +__m256 exp(__m256 a) { return exp256_ps(a); } __m256 relu(const __m256 a) { __m256 tmp = _mm256_set1_ps(0.0f); diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc new file mode 100644 index 0000000000..21ec78f962 --- /dev/null +++ b/paddle/operators/math/detail/hl_cpu_functions.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "hl_functions.h" + +namespace hppl { +namespace typef { + +float relu(const float a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +float tanh(const float a) { + float tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +float linear(const float a) { return a; } + +float relu(const float a, const float b) { return a * (b > 0.0 ? 
1.0 : 0.0); } + +float sigmoid(const float a, const float b) { + return a * b * (static_cast(1) - b); +} + +float tanh(const float a, const float b) { + return a * (static_cast(1) - b * b); +} + +float linear(const float a, const float b) { return a; } + +} // namespace typef + +namespace typed { +double relu(const double a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +double linear(const double a) { return a; } + +double relu(const double a, const double b) { + return a * (b > 0.0 ? 1.0 : 0.0); +} + +double sigmoid(const double a, const double b) { + return a * b * (static_cast(1) - b); +} + +double tanh(const double a, const double b) { + return a * (static_cast(1) - b * b); +} + +double linear(const double a, const double b) { return a; } + +} // namespace typed +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h index c77c119dfe..3e2f0c9ee6 100644 --- a/paddle/operators/math/detail/hl_functions.h +++ b/paddle/operators/math/detail/hl_functions.h @@ -34,83 +34,28 @@ limitations under the License. */ #ifndef __NVCC__ namespace hppl { namespace typef { -/* - * forward activation - */ -float relu(const float a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -float sigmoid(const float a) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - float tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -float tanh(const float a) { - float tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? 
EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -float linear(const float a) { return a; } - -/* - * backward activation - */ -float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } +float relu(const float a); +float sigmoid(const float a); +float tanh(const float a); +float linear(const float a); -float sigmoid(const float a, const float b) { - return a * b * (static_cast(1) - b); -} +float relu(const float a, const float b); +float sigmoid(const float a, const float b); +float tanh(const float a, const float b); +float linear(const float a, const float b); -float tanh(const float a, const float b) { - return a * (static_cast(1) - b * b); -} - -float linear(const float a, const float b) { return a; } } // namespace typef namespace typed { -/* - * forward activation - */ -double relu(const double a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -double sigmoid(const double a) { - const double min = SIGMOID_THRESHOLD_MIN; - const double max = SIGMOID_THRESHOLD_MAX; - double tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -double tanh(const double a) { - double tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -double linear(const double a) { return a; } - -/* - * backward activation - */ -double relu(const double a, const double b) { - return a * (b > 0.0 ? 
1.0 : 0.0); -} - -double sigmoid(const double a, const double b) { - return a * b * (static_cast(1) - b); -} - -double tanh(const double a, const double b) { - return a * (static_cast(1) - b * b); -} - -double linear(const double a, const double b) { return a; } +double relu(const double a); +double sigmoid(const double a); +double tanh(const double a); +double linear(const double a); + +double relu(const double a, const double b); +double sigmoid(const double a, const double b); +double tanh(const double a, const double b); +double linear(const double a, const double b); } // namespace typed } // namespace hppl diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 01310a49f8..36f3030348 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -19,6 +19,8 @@ limitations under the License. */ #include "paddle/platform/cuda_helper.h" #include "paddle/platform/device_context.h" +#include + namespace paddle { namespace operators { namespace math { @@ -29,11 +31,10 @@ namespace detail { * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmForward( - Op op, LstmMetaValue value, int frameSize, int batchSize, - typename hppl::ForwardActType::type active_node, - typename hppl::ForwardActType::type active_gate, - typename hppl::ForwardActType::type active_state) { +__global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -69,8 +70,10 @@ __global__ void KeLstmForward( rPrevState = value.prevStateValue[frameIdx]; } + hppl::gpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); + rOut, rCheckI, rCheckF, 
rCheckO, act(active_node), act(active_gate), + act(active_state)); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -87,11 +90,11 @@ __global__ void KeLstmForward( * grid(frameBlocks, batchBlocks) */ template -__global__ void KeLstmBackward( - Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize, typename hppl::BackwardActType::type active_node, - typename hppl::BackwardActType::type active_gate, - typename hppl::BackwardActType::type active_state) { +__global__ void KeLstmBackward(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -142,10 +145,11 @@ __global__ void KeLstmBackward( rPrevState = value.prevStateValue[frameIdx]; } + hppl::gpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - active_node, active_gate, active_state); + act(active_node), act(active_gate), act(active_state)); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -196,22 +200,16 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } - using type = typename hppl::ForwardActType::type; - hppl::gpu::ForwardAct act; - type act_node = act(active_node); - type act_gate = act(active_gate); - type act_state = act(active_state); - auto stream = reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, act_node, act_gate, act_state); + op, value, frameSize, batchSize, active_node, active_gate, active_gate); } else { KeLstmForward<<>>( - op, value, frameSize, 
batchSize, act_node, act_gate, act_state); + op, value, frameSize, batchSize, active_node, active_gate, active_gate); } } @@ -235,22 +233,18 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } - using type = typename hppl::BackwardActType::type; - hppl::gpu::BackwardAct act; - type act_node = act(active_node); - type act_gate = act(active_gate); - type act_state = act(active_state); - auto stream = reinterpret_cast(context).stream(); if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, act_node, act_gate, act_state); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } } diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index 293c9da3a0..d1c63bafe1 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -72,6 +72,8 @@ struct LstmUnitGradFunctor { }; template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; template class LstmUnitGradFunctor; } // namespace math diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu index aade604b9e..d942f60a26 100644 --- a/paddle/operators/math/lstm_compute.cu +++ b/paddle/operators/math/lstm_compute.cu @@ -26,18 +26,9 @@ struct LstmUnitFunctor { LstmMetaValue value, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batch_size; b++) { - detail::gpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, batch_size, ActiveType(cand_act), - ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - 
value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; - } - } + detail::gpu_lstm_forward(context, detail::forward::lstm(), value, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); } }; @@ -47,32 +38,15 @@ struct LstmUnitGradFunctor { LstmMetaValue value, LstmMetaGrad grad, int frame_size, int batch_size, std::string gate_act, std::string cell_act, std::string cand_act) { - for (int b = 0; b < batch_size; b++) { - detail::gpu_lstm_backward(context, detail::backward::lstm(), value, - grad, frame_size, batch_size, - ActiveType(cand_act), ActiveType(gate_act), - ActiveType(cell_act)); - - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; - } - - grad.gateGrad += frame_size * 4; - grad.stateGrad += frame_size; - grad.stateActiveGrad += frame_size; - grad.outputGrad += frame_size; - if (grad.prevStateGrad) { - grad.prevStateGrad += frame_size; - } - } + detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); } }; template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; template class LstmUnitGradFunctor; } // namespace math diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index ebf765c02e..bff9dd3ea4 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -53,7 +53,7 @@ struct LstmMetaGrad { T *checkOgGrad; }; -activation_mode_t ActiveType(const std::string &type) { +inline activation_mode_t ActiveType(const std::string &type) { if (type == "sigmoid") { return HL_ACTIVATION_SIGMOID; } else if (type == "relu") { diff --git 
a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 3813d71238..89b5116804 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -59,7 +59,7 @@ class LoDTensor2BatchFunctor { }; std::vector seq_info; - for (size_t seq_id = 0; seq_id < lod.size(); ++seq_id) { + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; seq_info.emplace_back(lod[seq_id], length, seq_id); } @@ -83,10 +83,11 @@ class LoDTensor2BatchFunctor { // The batch number represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. - auto batch_lods = batch.lod(); - if (batch_lods.size() == 0) { - batch_lods.resize(2); - } + + paddle::framework::LoD batch_lods; + batch_lods.push_back(std::vector{0}); + batch_lods.push_back(std::vector{0}); + // batch_lods[0] is the start positions for batch LoDTensor int num_batch = (size_t)seq_info[0].length; batch_lods[0].resize(num_batch + 1); @@ -115,6 +116,7 @@ class LoDTensor2BatchFunctor { } batch_starts[n + 1] = batch_id; } + batch.set_lod(batch_lods); CopyMatrixRowsFunctor to_batch; to_batch(context, lod_tensor, seq2batch_idx, batch, true); @@ -130,12 +132,13 @@ class Batch2LoDTensorFunctor { auto in_lod = batch.lod(); PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, "The LoD size of input `batch` should be 2."); - auto out_lod = lod_tensor.lod(); - PADDLE_ENFORCE_EQ(out_lod[0][0], out_lod[1].size()); - PADDLE_ENFORCE_EQ(out_lod[0][0], lod_tensor.dims()[0]); - PADDLE_ENFORCE_EQ(out_lod[0][0], batch.dims()[0]); + auto out_lod = lod_tensor.lod()[0]; + auto num = out_lod[out_lod.size() - 1]; + PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); + PADDLE_ENFORCE_EQ(num, in_lod[1].size()); + PADDLE_ENFORCE_EQ(num, batch.dims()[0]); CopyMatrixRowsFunctor to_seq; - size_t* index = out_lod[1].data(); + size_t* index = in_lod[1].data(); to_seq(context, batch, index, lod_tensor, false); } }; diff --git 
a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index f3f4c84b2a..aa6a21b547 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -2,17 +2,26 @@ import unittest import numpy as np from op_test import OpTest +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + def identity(x): return x def sigmoid(x): - return 1. / (1. + np.exp(-x)) + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) def tanh(x): - return 2. * sigmoid(2. * x) - 1. + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. def relu(x): @@ -35,7 +44,7 @@ def lstm( g = np.dot(h_pre, w_h) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) - c, g_i, g_f, g_o = np.split(g, 4, axis=1) + c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) if w_c is None: g_i = gate_act(g_i) # 1 x D g_f = gate_act(g_f) # 1 x D @@ -43,7 +52,7 @@ def lstm( w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) g_i = gate_act(g_i + w_ic * c_pre) # 1 x D g_f = gate_act(g_f + w_fc * c_pre) # 1 x D - c = g_f * c_pre + g_i * cand_act(c) # 1 x D + c = g_f * c_pre + g_i * cand_act(c_tmp) # 1 x D if w_c is None: g_o = gate_act(g_o) # 1 x D @@ -51,12 +60,14 @@ def lstm( _, _, w_oc = np.split(w_c, 3, axis=1) g_o = gate_act(g_o + w_oc * c) # 1 x D h = g_o * cell_act(c) - return h, c + bg = np.concatenate((cand_act(c_tmp), g_i, g_f, g_o), axis=1) + return h, c, bg offset = lod[0] batch_size = len(offset) - 1 hidden = [] cell = [] + gate = [] if w_b is not None: input = input + np.tile(w_b, (offset[-1], 1)) for i in range(batch_size): @@ -64,44 +75,62 @@ def lstm( seq_len = offset[i + 1] - offset[i] x = input[offset[i]:offset[i + 1], :] h_pre = h0[i] # 1 x D - c_pre = h0[i] # 1 x D + c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - h_pre, c_pre = 
_step(x[j], w_h, w_c, h_pre, c_pre, gate_act, - cell_act, cand_act) + h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act, + cell_act, cand_act) hidden.append(h_pre.flatten()) cell.append(c_pre.flatten()) + gate.append(g_pre.flatten()) hidden = np.array(hidden).astype("float64") cell = np.array(cell).astype("float64") + gate = np.array(gate).astype("float64") + assert gate.shape == input.shape assert hidden.shape == (input.shape[0], input.shape[1] / 4) assert cell.shape == (input.shape[0], input.shape[1] / 4) - return hidden, cell + return hidden, cell, gate class LstmUnitTest(OpTest): def set_data(self): - lod = [[0, 2, 6, 9]] - shape = (9, 64) - - x = np.random.normal(size=(9, 4 * 64)).astype("float64") - h0 = np.random.normal(size=(4, 64)).astype("float64") - c0 = np.random.normal(size=(4, 64)).astype("float64") - w = np.random.normal(size=(64, 4 * 64)).astype("float64") - b = np.random.normal(size=(1, 7 * 64)).astype("float64") - - w_b = b[:, 4 * 64] - w_c = b[:, 4 * 64:] - h, c = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) - - self.inputs = {'Input': x, 'H0': h0, 'C0': c0, 'Weight': w, 'Bias': b} - self.inputs = {'Hidden': h, 'Cell': c} + D = 4 + #lod = [[0, 2, 6, 9]] + lod = [[0, 1]] + shape = (1, D) + + x = np.random.normal(size=(1, 4 * D)).astype("float64") + h0 = np.zeros((4, D)).astype("float64") + c0 = np.zeros((4, D)).astype("float64") + w = np.random.normal(size=(D, 4 * D)).astype("float64") + b = np.random.normal(size=(1, 7 * D)).astype("float64") + + w_b = b[:, 0:4 * D] + w_c = b[:, 4 * D:] + #h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) + h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, identity, identity, + identity) + + g_sort = np.zeros_like(x) + #idx = [2,6,0,3,7,1,4,8,5] + #for i, j in enumerate(idx): + # g_sort[i, :] = g[j, :] + + self.inputs = { + 'Input': (x, lod), + 'H0': h0, + 'C0': c0, + 'Weight': w, + 'Bias': b + } + self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} 
self.attrs = { 'usePeepholes': True, 'isReverse': False, - 'gateActivation': 'sigmoid', - 'cellActivation': 'tanh', - 'candidateActivation': 'tanh' + 'gateActivation': 'linear', + 'cellActivation': 'linear', + 'candidateActivation': 'linear' } def setUp(self): From a461bf139dc7d0d2c6e88d944df408b6578c7aa5 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 19 Oct 2017 11:42:53 +0800 Subject: [PATCH 5/9] Add missing file. --- paddle/framework/CMakeLists.txt | 7 ------- paddle/operators/lstm_op.cu | 23 +++++++++++++++++++++++ paddle/operators/math/CMakeLists.txt | 1 + 3 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 paddle/operators/lstm_op.cu diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e57bcfabf8..6e32a1c99b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -44,13 +44,6 @@ cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward) -set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op - mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op mean_op) -if(WITH_GPU) - # nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) -else() - # cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) -endif() cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/operators/lstm_op.cu b/paddle/operators/lstm_op.cu new file mode 100644 index 0000000000..9ad5694155 --- /dev/null +++ b/paddle/operators/lstm_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/lstm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_GPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 0c48f0d050..5598669ef9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -17,6 +17,7 @@ else() cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) + cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) endif() From 8bec26be039a43d584e7260fd46df2ea7cac705e Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 19 Oct 2017 17:28:21 +0800 Subject: [PATCH 6/9] Add missing file of math/detail/CMakeLists.txt --- paddle/operators/math/detail/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 paddle/operators/math/detail/CMakeLists.txt diff --git a/paddle/operators/math/detail/CMakeLists.txt b/paddle/operators/math/detail/CMakeLists.txt new file mode 100644 index 0000000000..49cf228de2 --- /dev/null +++ b/paddle/operators/math/detail/CMakeLists.txt @@ -0,0 +1,5 @@ +if(WITH_AVX) + cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc) +else() + 
cc_library(activation_functions SRCS hl_cpu_functions.cc) +endif() From 17e33738f2f50c0417a4faf9dddd0c39cde17031 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 19 Oct 2017 22:13:02 +0800 Subject: [PATCH 7/9] Enhance unit testing and fix bug. --- paddle/operators/lstm_op.h | 17 +-- .../operators/math/detail/hl_gpu_functions.h | 6 +- .../operators/math/detail/lstm_gpu_kernel.h | 6 +- paddle/operators/math/lstm_compute.h | 2 +- python/paddle/v2/framework/tests/op_test.py | 4 +- .../paddle/v2/framework/tests/test_lstm_op.py | 128 ++++++++++++------ 6 files changed, 101 insertions(+), 62 deletions(-) diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index affa44c6fb..b9d4ae3a6f 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -56,10 +56,6 @@ class LSTMKernel : public framework::OpKernel { framework::DDim dims({in_dims[0], frame_size}); if (bias) { - // framework::Tensor cpu_t; - // cpu_t.mutable_data(in_dims, platform::CPUPlace()); - // cpu_t.CopyFrom(*batch_gate, platform::CPUPlace(), - // ctx.device_context()); Eigen::array extents({{1, 4 * frame_size}}); Eigen::array offsets({{0, 0}}); auto b = EigenMatrix::From(*bias); @@ -105,14 +101,14 @@ class LSTMKernel : public framework::OpKernel { int cur_batch_size = bend - bstart; if (n != 0) { - int pre_end = batch_lod[n - 1]; - auto pre_hidden_t = batch_out.Slice(pre_end, bstart); + int pre_h_start = batch_lod[n - 1]; + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); math::matmul(ctx.device_context(), pre_hidden_t, false, *weight, false, static_cast(1.0), &gate_t, - static_cast(0.0)); + static_cast(1.0)); } - // else if : how to pass the state from - // last mini-batch will be supported later + // else if : support the initial hidden and cell lstm_value.gateValue = gate_t.data(); lstm_value.outputValue = out_t.data(); @@ -132,9 +128,6 @@ class LSTMKernel : public framework::OpKernel { 
batch_cell.set_lod(batch_gate->lod()); // restore the output cell state in LoDTensor from the batch cell to_seq(ctx.device_context(), batch_cell, *cell_out); - - auto t = framework::EigenVector::Flatten(*batch_gate); - t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); } }; diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h index eee93dd578..72f2204e7b 100644 --- a/paddle/operators/math/detail/hl_gpu_functions.h +++ b/paddle/operators/math/detail/hl_gpu_functions.h @@ -30,7 +30,9 @@ __device__ static float sigmoid(const float a) { } __device__ static float tanh(const float a) { - return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f; + float tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return __fdividef(2.0f, (1.0f + __expf(-2.0f * tmp))) - 1.0f; } __device__ static float linear(const float a) { return a; } @@ -63,6 +65,8 @@ __device__ static double sigmoid(const double a) { } __device__ static double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? 
EXP_MAX_INPUT : tmp; return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0; } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 36f3030348..9573eaefb6 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -205,11 +205,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, active_gate); + op, value, frameSize, batchSize, active_node, active_gate, + active_state); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, active_gate); + op, value, frameSize, batchSize, active_node, active_gate, + active_state); } } diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index bff9dd3ea4..c58a1ad0d6 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -60,7 +60,7 @@ inline activation_mode_t ActiveType(const std::string &type) { return HL_ACTIVATION_RELU; } else if (type == "tanh") { return HL_ACTIVATION_TANH; - } else if (type == "linear" || type == "") { + } else if (type == "linear" || type == "identity" || type == "") { return HL_ACTIVATION_LINEAR; } else { PADDLE_THROW("Do not support activation type."); diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 215fa0b94e..169052fe41 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -242,7 +242,7 @@ class OpTest(unittest.TestCase): self.assertTrue( np.allclose( actual, expect, atol=atol), - "output name: " + out_name + " has diff.") + "Output (" + out_name + ") has diff at " + str(place)) else: actual = np.array(self.scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] @@ -250,7 +250,7 @@ class OpTest(unittest.TestCase): self.assertTrue( 
np.allclose( actual, expect, atol=atol), - "output name: " + out_name + " has diff.") + "Output (" + out_name + ") has diff at " + str(place)) def check_output(self, atol=1e-5): places = [core.CPUPlace()] diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index aa6a21b547..bcce8d32c9 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -28,6 +28,14 @@ def relu(x): return np.maximum(x, 0) +ACTVATION = { + 'identity': identity, + 'sigmoid': sigmoid, + 'tanh': tanh, + 'relu': relu +} + + def lstm( input, # T x 4D lod, # 1 x N @@ -37,37 +45,45 @@ def lstm( w_b=None, # 1 x 4D w_c=None, # 1 x 3D is_reverse=False, - gate_act=None, - cell_act=None, - cand_act=None): - def _step(x, w_h, w_c, h_pre, c_pre, gate_act, cell_act, cand_act): + act_gate=None, + act_cell=None, + act_cand=None): + def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand): g = np.dot(h_pre, w_h) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) if w_c is None: - g_i = gate_act(g_i) # 1 x D - g_f = gate_act(g_f) # 1 x D + g_i = act_gate(g_i) # 1 x D + g_f = act_gate(g_f) # 1 x D else: w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) - g_i = gate_act(g_i + w_ic * c_pre) # 1 x D - g_f = gate_act(g_f + w_fc * c_pre) # 1 x D - c = g_f * c_pre + g_i * cand_act(c_tmp) # 1 x D + g_i = act_gate(g_i + w_ic * c_pre) # 1 x D + g_f = act_gate(g_f + w_fc * c_pre) # 1 x D + c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D if w_c is None: - g_o = gate_act(g_o) # 1 x D + g_o = act_gate(g_o) # 1 x D else: _, _, w_oc = np.split(w_c, 3, axis=1) - g_o = gate_act(g_o + w_oc * c) # 1 x D - h = g_o * cell_act(c) - bg = np.concatenate((cand_act(c_tmp), g_i, g_f, g_o), axis=1) + g_o = act_gate(g_o + w_oc * c) # 1 x D + h = g_o * act_cell(c) + bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1) return h, c, bg + def _reverse(x, lod): + y = 
np.zeros_like(x) + for i in range(len(lod) - 1): + b, e = lod[i], lod[i + 1] + y[b:e, :] = np.flip(x[b:e, :], 0) + return y + offset = lod[0] batch_size = len(offset) - 1 hidden = [] cell = [] gate = [] + input = _reverse(input, offset) if is_reverse else input if w_b is not None: input = input + np.tile(w_b, (offset[-1], 1)) for i in range(batch_size): @@ -78,8 +94,8 @@ def lstm( c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act, - cell_act, cand_act) + h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, + act_cell, act_cand) hidden.append(h_pre.flatten()) cell.append(c_pre.flatten()) gate.append(g_pre.flatten()) @@ -87,38 +103,53 @@ def lstm( hidden = np.array(hidden).astype("float64") cell = np.array(cell).astype("float64") gate = np.array(gate).astype("float64") + + hidden = _reverse(hidden, offset) if is_reverse else hidden + cell = _reverse(cell, offset) if is_reverse else cell + assert gate.shape == input.shape assert hidden.shape == (input.shape[0], input.shape[1] / 4) assert cell.shape == (input.shape[0], input.shape[1] / 4) return hidden, cell, gate -class LstmUnitTest(OpTest): +class TestLstmOp(OpTest): def set_data(self): - D = 4 - #lod = [[0, 2, 6, 9]] - lod = [[0, 1]] - shape = (1, D) - - x = np.random.normal(size=(1, 4 * D)).astype("float64") - h0 = np.zeros((4, D)).astype("float64") - c0 = np.zeros((4, D)).astype("float64") - w = np.random.normal(size=(D, 4 * D)).astype("float64") - b = np.random.normal(size=(1, 7 * D)).astype("float64") - - w_b = b[:, 0:4 * D] - w_c = b[:, 4 * D:] - #h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh) - h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, identity, identity, - identity) + self.lod = [[0, 2, 6, 9]] + self.D = 64 + self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + + self.act_gate = "sigmoid" + self.act_cell = "tanh" + self.act_cand = "tanh" + + self.is_reverse = False + + def 
setUp(self): + self.set_data() + self.op_type = "lstm" + + T = self.lod[0][-1] + N = len(self.lod[0]) - 1 + + x = np.random.normal(size=(T, 4 * self.D)).astype("float64") + h0 = np.zeros((N, self.D)).astype("float64") + c0 = np.zeros((N, self.D)).astype("float64") + w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64") + b = np.random.normal(size=(1, 7 * self.D)).astype("float64") + + w_b = b[:, 0:4 * self.D] + w_c = b[:, 4 * self.D:] + h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, + ACTVATION[self.act_gate], ACTVATION[self.act_cell], + ACTVATION[self.act_cand]) g_sort = np.zeros_like(x) - #idx = [2,6,0,3,7,1,4,8,5] - #for i, j in enumerate(idx): - # g_sort[i, :] = g[j, :] + for i, j in enumerate(self.sort_idx): + g_sort[i, :] = g[j, :] self.inputs = { - 'Input': (x, lod), + 'Input': (x, self.lod), 'H0': h0, 'C0': c0, 'Weight': w, @@ -127,19 +158,28 @@ class LstmUnitTest(OpTest): self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} self.attrs = { 'usePeepholes': True, - 'isReverse': False, - 'gateActivation': 'linear', - 'cellActivation': 'linear', - 'candidateActivation': 'linear' + 'isReverse': self.is_reverse, + 'gateActivation': 'sigmoid', + 'cellActivation': 'tanh', + 'candidateActivation': 'tanh' } - def setUp(self): - self.set_data() - self.op_type = "lstm" - def test_check_output(self): self.check_output() +class TestLstmOpRerverse(TestLstmOp): + def set_data(self): + self.lod = [[0, 2, 6, 9]] + self.D = 64 + self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + + self.act_gate = "sigmoid" + self.act_cell = "tanh" + self.act_cand = "tanh" + + self.is_reverse = True + + if __name__ == "__main__": unittest.main() From 65906ef1d0782e76b3bc40c09df30a01c423fb7c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 20 Oct 2017 12:52:35 -0700 Subject: [PATCH 8/9] Several Enhancement --- paddle/operators/lstm_op.cc | 16 ++--- paddle/operators/lstm_op.h | 18 ++--- paddle/operators/math/detail/lstm_kernel.h | 83 +++++++++++----------- 
paddle/operators/math/lstm_compute.cc | 9 +-- paddle/operators/math/lstm_compute.cu | 9 +-- paddle/operators/math/lstm_compute.h | 9 +-- paddle/operators/math/sequence2batch.cc | 2 - paddle/operators/math/sequence2batch.cu | 2 +- paddle/operators/math/sequence2batch.h | 51 ++++++------- 9 files changed, 102 insertions(+), 97 deletions(-) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index f360502e66..222aeeace5 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel { } else { PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, "The second dimension of Input(Bias) should be " - "4 * %d if diable peepholes connection", + "4 * %d if disable peepholes connection", frame_size); } ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); @@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(LoDTensor) the first input is a LodTensor, which support " "variable-time length input sequence. The underlying tensor in " - "this LoDTenosr is a matrix with shape (T X 4D), where, T is the " + "this LoDTensor is a matrix with shape (T X 4D), where, T is the " "total time steps in this mini-batch, D is the hidden size."); AddInput("H0", "(Tensor, optional) the initial hidden state is an optional " @@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " - "and output gate aftern the nonlinear computation. This " + "and output gate after the nonlinear computation. This " "LoDTensor has the same shape with the reorganized input, which " "was also be called batch input. The LoD size is 2. 
The first " "LoD is the batch offsets and the second LoD contains the " @@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(false); AddAttr( "gateActivation", - "(string, defalut: sigmoid)" + "(string, default: sigmoid)" "The activation for input gate, forget gate and output " - "gate, `sigmoid` by defalut.") + "gate, `sigmoid` by default.") .SetDefault("sigmoid"); AddAttr("cellActivation", - "(string, defalut: tanh)" + "(string, default: tanh)" "The activation for cell output, `tanh` by defalut.") .SetDefault("tanh"); AddAttr("candidateActivation", - "(string, defalut: tanh)" + "(string, default: tanh)" "The activation for candidate hidden state, " - "`tanh` by defalut.") + "`tanh` by default.") .SetDefault("tanh"); AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index b9d4ae3a6f..5e10036707 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel { to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); auto in_dims = input->dims(); - int frame_size = in_dims[1] / 4; + int frame_size = static_cast(in_dims[1] / 4); framework::DDim dims({in_dims[0], frame_size}); if (bias) { @@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel { math::LstmMetaValue lstm_value; T* bias_data = const_cast(bias->data()); - // the code styple in LstmMetaValue will be updated later. + // the code style in LstmMetaValue will be updated later. 
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.checkOg = lstm_value.checkFg + frame_size; @@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel { framework::LoDTensor batch_cell_pre_act; batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); - auto batch_lod = batch_gate->lod()[0]; - int num_batch = batch_lod.size() - 1; + auto& batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; auto gate_act = ctx.Attr("gateActivation"); auto cell_act = ctx.Attr("cellActivation"); auto cand_act = ctx.Attr("candidateActivation"); - for (int n = 0; n < num_batch; n++) { - int bstart = batch_lod[n]; - int bend = batch_lod[n + 1]; + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor out_t = batch_out.Slice(bstart, bend); @@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel { int cur_batch_size = bend - bstart; if (n != 0) { - int pre_h_start = batch_lod[n - 1]; + int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); math::matmul(ctx.device_context(), pre_hidden_t, false, *weight, false, static_cast(1.0), &gate_t, static_cast(1.0)); } - // else if : support the initial hidden and cell + // else if : FIXME support the initial hidden and cell lstm_value.gateValue = gate_t.data(); lstm_value.outputValue = out_t.data(); diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index b1e59a4ee8..6f3ead2397 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -13,12 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/platform/hostdevice.h" -#ifdef __CUDA_ARCH__ -#define INLINE __device__ inline -#else -#define INLINE inline -#endif +#include namespace paddle { namespace operators { @@ -30,12 +27,12 @@ namespace forward { template class lstm { public: - INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &prevState, T &state, T &stateAtv, T &output, - T &checkI, T &checkF, T &checkO, - typename hppl::ForwardActType::type actInput, - typename hppl::ForwardActType::type actGate, - typename hppl::ForwardActType::type actState) { + HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &prevState, T &state, T &stateAtv, T &output, + T &checkI, T &checkF, T &checkO, + typename hppl::ForwardActType::type actInput, + typename hppl::ForwardActType::type actGate, + typename hppl::ForwardActType::type actState) { valueIn = actInput(valueIn); valueIg = actGate(valueIg + prevState * checkI); valueFg = actGate(valueFg + prevState * checkF); @@ -45,17 +42,19 @@ class lstm { output = valueOg * stateAtv; } #ifndef __NVCC__ -#ifndef __AVX__ +#ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default static const bool avx = false; #else - static const bool avx = true; - INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, - __m256 &valueOg, __m256 &prevState, __m256 &state, - __m256 &stateAtv, __m256 &output, __m256 &checkI, - __m256 &checkF, __m256 &checkO, - hppl::Active<__m256>::forward actInput, - hppl::Active<__m256>::forward actGate, - hppl::Active<__m256>::forward actState) { + // Only float support AVX optimization + static const bool avx = std::is_same::value; + + HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &prevState, __m256 &state, + __m256 &stateAtv, __m256 &output, __m256 &checkI, + __m256 &checkF, __m256 &checkO, + hppl::Active<__m256>::forward actInput, + hppl::Active<__m256>::forward actGate, + hppl::Active<__m256>::forward actState) { valueIn = actInput(valueIn); valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); @@ -76,14 +75,15 @@ namespace backward { template class lstm { public: - INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &gradIn, T &gradIg, T &gradFg, T &gradOg, - T &prevState, T &prevStateGrad, T &state, T &stateGrad, - T &stateAtv, T &outputGrad, T &checkI, T &checkF, - T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, - typename hppl::BackwardActType::type actInput, - typename hppl::BackwardActType::type actGate, - typename hppl::BackwardActType::type actState) { + HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &gradIn, T &gradIg, T &gradFg, T &gradOg, + T &prevState, T &prevStateGrad, T &state, + T &stateGrad, T &stateAtv, T &outputGrad, + T &checkI, T &checkF, T &checkO, T &checkIGrad, + T &checkFGrad, T &checkOGrad, + typename hppl::BackwardActType::type actInput, + typename hppl::BackwardActType::type actGate, + typename hppl::BackwardActType::type actState) { 
gradOg = actGate(outputGrad * stateAtv, valueOg); stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradIn = actInput(stateGrad * valueIg, valueIn); @@ -95,21 +95,22 @@ class lstm { checkOGrad = gradOg * state; } #ifndef __NVCC__ -#ifndef __AVX__ +#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default static const bool avx = false; #else - static const bool avx = true; - INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, - __m256 &valueOg, __m256 &gradIn, __m256 &gradIg, - __m256 &gradFg, __m256 &gradOg, __m256 &prevState, - __m256 &prevStateGrad, __m256 &state, - __m256 &stateGrad, __m256 &stateAtv, - __m256 &outputGrad, __m256 &checkI, __m256 &checkF, - __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, - __m256 &checkOGrad, - hppl::Active<__m256>::backward actInput, - hppl::Active<__m256>::backward actGate, - hppl::Active<__m256>::backward actState) { + // Only float support AVX optimization + static const bool avx = std::is_same::value; + HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &gradIn, __m256 &gradIg, + __m256 &gradFg, __m256 &gradOg, __m256 &prevState, + __m256 &prevStateGrad, __m256 &state, + __m256 &stateGrad, __m256 &stateAtv, + __m256 &outputGrad, __m256 &checkI, __m256 &checkF, + __m256 &checkO, __m256 &checkIGrad, + __m256 &checkFGrad, __m256 &checkOGrad, + hppl::Active<__m256>::backward actInput, + hppl::Active<__m256>::backward actGate, + hppl::Active<__m256>::backward actState) { gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg); stateGrad = _mm256_add_ps( actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index d1c63bafe1..0febf8e3b7 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -24,8 +24,8 @@ template struct LstmUnitFunctor { static void 
compute(const platform::DeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - std::string gate_act, std::string cell_act, - std::string cand_act) { + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, ActiveType(cand_act), ActiveType(gate_act), @@ -45,8 +45,9 @@ template struct LstmUnitGradFunctor { static void compute(const platform::DeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, std::string gate_act, - std::string cell_act, std::string cand_act) { + int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, frame_size, ActiveType(cand_act), diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu index d942f60a26..b2122f2a5c 100644 --- a/paddle/operators/math/lstm_compute.cu +++ b/paddle/operators/math/lstm_compute.cu @@ -24,8 +24,8 @@ template struct LstmUnitFunctor { static void compute(const platform::DeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - std::string gate_act, std::string cell_act, - std::string cand_act) { + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { detail::gpu_lstm_forward(context, detail::forward::lstm(), value, frame_size, batch_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); @@ -36,8 +36,9 @@ template struct LstmUnitGradFunctor { static void compute(const platform::DeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, std::string gate_act, - std::string cell_act, std::string cand_act) { + int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, 
+ const std::string& cand_act) { detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, frame_size, batch_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index c58a1ad0d6..28d2c6fd3b 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -72,8 +72,8 @@ class LstmUnitFunctor { public: static void compute(const platform::DeviceContext &context, LstmMetaValue value, int frame_size, int batch_size, - std::string gate_act, std::string cell_act, - std::string cand_act); + const std::string &gate_act, const std::string &cell_act, + const std::string &cand_act); }; template @@ -81,8 +81,9 @@ class LstmUnitGradFunctor { public: static void compute(const platform::DeviceContext &context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, std::string gate_act, - std::string cell_act, std::string cand_act); + int frame_size, int batch_size, + const std::string &gate_act, const std::string &cell_act, + const std::string &cand_act); }; } // namespace math diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index 10c6e105b9..00de56f7cd 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor { template class CopyMatrixRowsFunctor; template class CopyMatrixRowsFunctor; -template class LoDTensor2BatchFunctor; -template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index e478c46db7..4f34994678 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -21,7 +21,7 @@ namespace math { template __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* 
index, int64_t height, int64_t width, - const bool is_src_index) { + bool is_src_index) { int idx = threadIdx.x; int idy = threadIdx.y; int id = blockIdx.x + idy * GridDimX; diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 89b5116804..690cac0587 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -31,33 +31,33 @@ class CopyMatrixRowsFunctor { // The indexed rows are based on the input index. void operator()(const platform::DeviceContext& context, const framework::LoDTensor& src, const size_t* index, - framework::LoDTensor& dst, const bool is_src_index); + framework::LoDTensor& dst, bool is_src_index); }; template class LoDTensor2BatchFunctor { + // Calculate the length of each sequence and + // sort sequence index by the length. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seq_idx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + public: void operator()(const platform::DeviceContext& context, const framework::LoDTensor& lod_tensor, - framework::LoDTensor& batch, const bool is_reverse) const { + framework::LoDTensor& batch, bool is_reverse) const { auto lods = lod_tensor.lod(); PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); auto lod = lods[0]; - // Calculate the length of each sequence and - // sort sequence index by the length. 
- // example: sequences = {s0, s1, s2} - // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 - // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} - // - struct SeqInfo { - SeqInfo(int start, int length, int seq_idx) - : start(start), length(length), seq_idx(seq_idx) {} - int start; - int length; - int seq_idx; - }; - std::vector seq_info; for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; @@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor { // batchIndex = {b0, b1, b2, b3, b4} // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // batch_start_positions[0] = len(b0) + // batch_start_positions[1] = len(b0) + len(b1) + // batch_start_positions[2] = len(b0) + len(b1) + len(b2) + // ... // seq2batch_idx[12] = {4, 0, 9, // 5, 1, 10, // 6, 2, 11, // 7, 3, // 8} - // The batch number represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. paddle::framework::LoD batch_lods; - batch_lods.push_back(std::vector{0}); - batch_lods.push_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); // batch_lods[0] is the start positions for batch LoDTensor - int num_batch = (size_t)seq_info[0].length; - batch_lods[0].resize(num_batch + 1); + int num_batch = seq_info[0].length; + batch_lods[0].resize(static_cast(num_batch + 1)); // batch_lods[1] is the raw index in the input LoDTensor auto dims = lod_tensor.dims(); - batch_lods[1].resize(dims[0]); + batch_lods[1].resize(static_cast(dims[0])); size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; for (size_t n = 0; n < num_batch; n++) { - int batch_id = batch_starts[n]; + auto batch_id = static_cast(batch_starts[n]); for (size_t i = 0; i < seq_info.size(); ++i) { size_t seq_len = seq_info[i].length; int start = seq_info[i].start; @@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor { 
break; } } - batch_starts[n + 1] = batch_id; + batch_starts[n + 1] = static_cast(batch_id); } batch.set_lod(batch_lods); From 64fe9bcc5c1dcbf90f54cb649f40c4e2a1f19ff0 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 23 Oct 2017 17:51:17 +0800 Subject: [PATCH 9/9] Update lstm comments and fix bug. --- paddle/framework/CMakeLists.txt | 3 ++- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/lstm_op.cc | 18 +++++++++--------- paddle/operators/lstm_op.h | 6 ++---- paddle/operators/math/sequence2batch.cc | 2 ++ 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6e32a1c99b..85752f5d6b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -20,7 +20,8 @@ proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) -cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) +cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc +device_context) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0c53ed3cdc..f97bc837dc 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) -op_library(lstm_op DEPS sequence2batch lstm_compute math_function) +op_library(lstm_op DEPS sequence2batch lstm_compute) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git 
a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 222aeeace5..0a089b7c2d 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -98,18 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "batch size. `H0` and `C0` can be NULL but only at the same time"); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." - " - The shape is (D x 4*D), where D is the hidden size. " - " - Weight = {W_ih, W_fh, W_ch, W_oh}"); + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); AddInput("Bias", "(Tensor) the learnable weights, which contains two parts: " "input-hidden bias weight and peephole connections weight if " - "seting `usePeepholes` True. " + "setting `usePeepholes` True. " "1. `usePeepholes = False` " - " - The shape is (1 x 4*D). " - " - Bias = {b_i, b_f, b_c, b_o}." + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." "2. `usePeepholes = True` " - " - The shape is (1 x 7*D). " - " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " "and output gate after the nonlinear computation. This " @@ -184,8 +184,8 @@ Set `usePeepholes` False to disable peephole connection [2]. The formula is omitted here. @note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ -operations on the input x_{t} were NOT included in this operator. The -users can choose to use fully-connect operator before LSTM operator. +operations on the input x_{t} were NOT included in this operator. +Users can choose to use fully-connect operator before LSTM operator. [1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory recurrent neural network architectures for large scale acoustic modeling. 
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 5e10036707..b3e3db9726 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -76,14 +76,12 @@ class LSTMKernel : public framework::OpKernel { lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.prevStateValue = nullptr; - framework::LoDTensor batch_out; + framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; batch_out.mutable_data(dims, ctx.GetPlace()); - framework::LoDTensor batch_cell; batch_cell.mutable_data(dims, ctx.GetPlace()); - framework::LoDTensor batch_cell_pre_act; batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); - auto& batch_starts = batch_gate->lod()[0]; + auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; auto gate_act = ctx.Attr("gateActivation"); auto cell_act = ctx.Attr("cellActivation"); diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index 00de56f7cd..10c6e105b9 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -51,6 +51,8 @@ class CopyMatrixRowsFunctor { template class CopyMatrixRowsFunctor; template class CopyMatrixRowsFunctor; +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor;