parent
9efd5422f9
commit
8728b3cce2
@@ -0,0 +1,185 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lstm_op.h"

namespace paddle {
namespace operators {

class LSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                   "Output(Hidden) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                   "Output(Cell) of LSTM should not be null.");

    auto x_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Input)'s rank must be 2.");

    if (ctx->HasInput("H0")) {
      PADDLE_ENFORCE(ctx->HasInput("C0"),
                     "Input(H0) and Input(C0) of LSTM should not "
                     "be null at the same time.");
      auto h_dims = ctx->GetInputDim("H0");
      auto c_dims = ctx->GetInputDim("C0");
      PADDLE_ENFORCE(h_dims == c_dims,
                     "The dimension of Input(H0) and Input(C0) "
                     "should be the same.");
    }

    ctx->SetOutputDim("Hidden", x_dims);
    ctx->SetOutputDim("Cell", x_dims);
    ctx->ShareLoD("Input", "Hidden");
    ctx->ShareLoD("Input", "Cell");
  }
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor) the first input is a LoDTensor, which supports "
             "variable-length input sequences. The underlying tensor in "
             "this LoDTensor is a matrix with shape (T x D), where T is the "
             "total number of time steps in this mini-batch and D is the "
             "hidden size.");
    AddInput("H0",
             "(Tensor, optional) the initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size.");
    AddInput("C0",
             "(Tensor, optional) the initial cell state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size. `H0` and `C0` can be NULL, but only at the same "
             "time.");
    AddInput("Weight",
             "(Tensor) the learnable hidden-hidden weights."
             " - The shape is (D x 4*D), where D is the hidden size. "
             " - Weight = {W_ih, W_fh, W_ch, W_oh}");
    AddInput("Bias",
             "(Tensor) the learnable weights, which contain two parts: "
             "the input-hidden bias and, if `use_peepholes` is set to True, "
             "the peephole connection weights. "
             "1. `use_peepholes = False` "
             " - The shape is (1 x 4*D). "
             " - Bias = {b_i, b_f, b_c, b_o}."
             "2. `use_peepholes = True` "
             " - The shape is (1 x 7*D). "
             " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
    AddOutput("Hidden",
              "(LoDTensor) the hidden state LoDTensor of the LSTM operator. "
              "The shape and LoD are the same as those of `Input`.");
    AddOutput("Cell",
              "(LoDTensor) the cell state LoDTensor of the LSTM operator. "
              "The shape and LoD are the same as those of `Input`.");
    AddAttr<bool>("use_peepholes",
                  "(bool, default: True) "
                  "whether to enable diagonal/peephole connections.")
        .SetDefault(true);
    AddAttr<std::string>(
        "gate_activation",
        "(string, default: sigmoid) "
        "The activation for the input gate, forget gate and output gate, "
        "`sigmoid` by default.")
        .SetDefault("sigmoid");
    AddAttr<std::string>("cell_activation",
                         "(string, default: tanh) "
                         "The activation for the cell output, `tanh` by "
                         "default.")
        .SetDefault("tanh");
    AddAttr<std::string>("candidate_activation",
                         "(string, default: tanh) "
                         "The activation for the candidate hidden state, "
                         "`tanh` by default.")
        .SetDefault("tanh");
    AddComment(R"DOC(Long Short-Term Memory (LSTM) Operator

The default implementation uses diagonal/peephole connections [1]; the
formulas are as follows:

i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)

f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)

\tilde{c_t} = act_g(W_{cx}x_{t} + W_{ch}h_{t-1} + b_c)

o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)

c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t}

h_t = o_t ⊙ act_h(c_t)

where the W terms denote weight matrices (e.g. \f$W_{ix}\f$ is the matrix
of weights from the input to the input gate), and \f$W_{ic}, W_{fc}, W_{oc}\f$
are diagonal weight matrices for the peephole connections. In our
implementation, we use vectors to represent these diagonal weight matrices.
The b terms denote bias vectors (\f$b_i\f$ is the input gate bias vector),
\f$\sigma\f$ is the non-linear activation, such as the logistic sigmoid
function, and \f$i, f, o\f$ and \f$c\f$ are respectively the input gate,
forget gate, output gate and cell activation vectors, all of which have the
same size as the cell output activation vector \f$h\f$.

⊙ denotes the element-wise product of vectors, and \f$act_g\f$ and \f$act_h\f$
are the cell input and cell output activation functions; `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called the candidate hidden state,
which is computed from the current input and the previous hidden state.

Set `use_peepholes` to False to disable the peephole connections [2]. The
corresponding formulas are omitted here.

@note The projections \f$W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}\f$
of the input \f$x_{t}\f$ are NOT included in this operator. Users can apply a
fully-connected operator before the LSTM operator to compute them.

[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
INTERSPEECH, 2014.

[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory.
Neural Computation, 9(8):1735-1780, 1997.

)DOC");
  }
};

class LSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                   "Input(Hidden@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
                   "Input(Cell@GRAD) should not be null.");
    ctx->SetOutputDim(framework::GradVarName("Weight"),
                      ctx->GetInputDim("Weight"));
    ctx->SetOutputDim(framework::GradVarName("Bias"),
                      ctx->GetInputDim("Bias"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(lstm_grad,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, double>);
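For reference, the peephole formulas in the DOC block above can be checked with a small standalone sketch. The snippet below is not part of the commit: it applies one LSTM step to plain std::vector data, assuming the input projections (W_{ix}x_t, W_{fx}x_t, W_{cx}x_t, W_{ox}x_t) plus recurrent terms and biases are already summed into a 4*D `gates` buffer, as the @note suggests they would be produced by a preceding fully-connected operator. The names LstmStep, Sigmoid and the gate layout are illustrative only.

// Standalone sketch of one peephole-LSTM step, following the formulas in the
// operator comment above. Not part of the operator; all names are illustrative.
#include <cmath>
#include <cstddef>
#include <vector>

static double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// `gates` holds the precomputed projections for the four gates laid out as
// {i, f, c, o}, each of size D (matching the 4*D Weight/Bias layout).
// `wc_i`, `wc_f`, `wc_o` are the diagonal peephole weights W_ic, W_fc, W_oc
// stored as vectors, as described in the DOC block.
void LstmStep(const std::vector<double>& gates, const std::vector<double>& wc_i,
              const std::vector<double>& wc_f, const std::vector<double>& wc_o,
              std::vector<double>* cell, std::vector<double>* hidden) {
  const std::size_t d = cell->size();
  for (std::size_t j = 0; j < d; ++j) {
    double c_prev = (*cell)[j];
    double i = Sigmoid(gates[0 * d + j] + wc_i[j] * c_prev);  // input gate
    double f = Sigmoid(gates[1 * d + j] + wc_f[j] * c_prev);  // forget gate
    double c_tilde = std::tanh(gates[2 * d + j]);             // candidate state
    double c = f * c_prev + i * c_tilde;                      // new cell state
    double o = Sigmoid(gates[3 * d + j] + wc_o[j] * c);       // output gate
    (*cell)[j] = c;
    (*hidden)[j] = o * std::tanh(c);                          // new hidden state
  }
}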
@@ -0,0 +1,38 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::LoDTensor;
using framework::Tensor;

template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {}
};

template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {}
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {
namespace math {

template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensor2Functor<platform::CPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {
namespace math {

template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
template class Batch2LoDTensor2Functor<platform::GPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,113 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <vector>

#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {

template <typename Place, typename T>
class LoDTensor2BatchFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::LoDTensor& lod_tensor,
                  framework::LoDTensor& batch, const bool is_reverse) const {
    auto lods = lod_tensor.lod();
    PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
    auto lod = lods[0];

    // Calculate the length of each sequence and
    // sort the sequence indices by length.
    // example: sequences = {s0, s1, s2}
    //          s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
    //          seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
    struct SeqInfo {
      SeqInfo(int start, int length, int seq_idx)
          : start(start), length(length), seq_idx(seq_idx) {}
      int start;
      int length;
      int seq_idx;
    };

    std::vector<SeqInfo> seq_info;
    for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
      int length = lod[seq_id + 1] - lod[seq_id];
      seq_info.emplace_back(lod[seq_id], length, seq_id);
    }

    std::sort(seq_info.begin(), seq_info.end(),
              [](SeqInfo a, SeqInfo b) { return a.length > b.length; });

    // Calculate the start position of each batch
    // (num_batch equals the maximum sequence length).
    // example: sequences = {s0, s1, s2}
    //          s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
    //          num_batch = 5,
    //          batches = {b0, b1, b2, b3, b4}
    //          b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
    //          batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
    //          seq2batch_idx[12] = {4, 0, 9,
    //                               5, 1, 10,
    //                               6, 2, 11,
    //                               7, 3,
    //                               8}

    // The batch number represents the batch size after rearranging the
    // input LoDTensor. It is also the maximum length of the input sequences.
    auto batch_lods = batch.lod();
    if (batch_lods.size() < 2) {
      batch_lods.resize(2);
    }
    // batch_lods[0] is the start positions for the batch LoDTensor.
    size_t num_batch = static_cast<size_t>(seq_info[0].length);
    batch_lods[0].resize(num_batch + 1);
    // batch_lods[1] is the raw index in the input LoDTensor.
    auto dims = lod_tensor.dims();
    batch_lods[1].resize(static_cast<size_t>(dims[0]));

    auto* batch_starts = batch_lods[0].data();
    auto* seq2batch_idx = batch_lods[1].data();
    batch_starts[0] = 0;
    for (size_t n = 0; n < num_batch; n++) {
      size_t batch_id = batch_starts[n];
      for (size_t i = 0; i < seq_info.size(); ++i) {
        size_t seq_len = seq_info[i].length;
        size_t start = seq_info[i].start;
        if (n < seq_len) {
          if (!is_reverse) {
            seq2batch_idx[batch_id] = start + n;
          } else {
            seq2batch_idx[batch_id] = start + seq_len - 1 - n;
          }
          batch_id++;
        } else {
          break;
        }
      }
      batch_starts[n + 1] = batch_id;
    }
    batch.set_lod(batch_lods);
  }
};

template <typename Place, typename T>
class Batch2LoDTensor2Functor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::LoDTensor& batch,
                  framework::LoDTensor& lod_tensor,
                  const bool is_reverse) const;
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
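The index bookkeeping described in the comments of LoDTensor2BatchFunctor can be reproduced with plain containers. The sketch below is not part of the commit: it hard-codes the level-0 LoD {0, 4, 9, 12} (the s0/s1/s2 example from the comments) and recomputes only the forward-order indices, which should print batch_starts = 0 3 6 9 11 12 and seq2batch_idx = 4 0 9 5 1 10 6 2 11 7 3 8, matching the example above. All names in it are illustrative.

// Standalone sketch of the sequence-to-batch index calculation used by
// LoDTensor2BatchFunctor, on the example from the comments above.
// Not part of the operator code; it only reproduces the index bookkeeping.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Level-0 LoD of {s0, s1, s2} with lengths 4, 5, 3.
  std::vector<std::size_t> lod = {0, 4, 9, 12};

  struct SeqInfo { std::size_t start, length, seq_idx; };
  std::vector<SeqInfo> seq_info;
  for (std::size_t i = 0; i + 1 < lod.size(); ++i) {
    seq_info.push_back({lod[i], lod[i + 1] - lod[i], i});
  }
  // Longest sequence first, as in the functor.
  std::sort(seq_info.begin(), seq_info.end(),
            [](const SeqInfo& a, const SeqInfo& b) { return a.length > b.length; });

  std::size_t num_batch = seq_info[0].length;  // 5 batches for this example
  std::vector<std::size_t> batch_starts(num_batch + 1, 0);
  std::vector<std::size_t> seq2batch_idx(lod.back(), 0);

  for (std::size_t n = 0; n < num_batch; ++n) {
    std::size_t batch_id = batch_starts[n];
    for (const auto& info : seq_info) {
      if (n >= info.length) break;  // shorter sequences drop out of later batches
      seq2batch_idx[batch_id++] = info.start + n;  // forward order (is_reverse = false)
    }
    batch_starts[n + 1] = batch_id;
  }

  for (auto v : batch_starts) std::cout << v << ' ';   // 0 3 6 9 11 12
  std::cout << '\n';
  for (auto v : seq2batch_idx) std::cout << v << ' ';  // 4 0 9 5 1 10 6 2 11 7 3 8
  std::cout << '\n';
  return 0;
}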