|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/attention_lstm_op.h"
|
|
|
|
|
#include <sys/time.h>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
@ -192,24 +193,23 @@ void AttentionLSTMOpMaker::Make() {
|
|
|
|
|
"(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
|
|
|
|
|
"Shape is (1 x 4D), where M is the x frame size")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
// TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"});
|
|
|
|
|
AddAttr<std::string>("gate_activation",
|
|
|
|
|
"(string, default: sigmoid)"
|
|
|
|
|
"The activation for input gate, forget gate and output "
|
|
|
|
|
"gate, `sigmoid` by default.")
|
|
|
|
|
.SetDefault("sigmoid")
|
|
|
|
|
.InEnum({"sigmoid"});
|
|
|
|
|
.InEnum({"sigmoid", "tanh", "relu", "identity"});
|
|
|
|
|
AddAttr<std::string>("cell_activation",
|
|
|
|
|
"(string, default: tanh)"
|
|
|
|
|
"The activation for cell output, `tanh` by defalut.")
|
|
|
|
|
.SetDefault("tanh")
|
|
|
|
|
.InEnum({"tanh"});
|
|
|
|
|
.InEnum({"sigmoid", "tanh", "relu", "identity"});
|
|
|
|
|
AddAttr<std::string>("candidate_activation",
|
|
|
|
|
"(string, default: tanh)"
|
|
|
|
|
"The activation for candidate hidden state, "
|
|
|
|
|
"`tanh` by default.")
|
|
|
|
|
.SetDefault("tanh")
|
|
|
|
|
.InEnum({"tanh"});
|
|
|
|
|
.InEnum({"sigmoid", "tanh", "relu", "identity"});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Attention Long-Short Term Memory (LSTM) Operator.
|
|
|
|
|
|
|
|
|
@ -273,22 +273,23 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X"); // T x M
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); // N x D
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0"); // N x D
|
|
|
|
|
auto* atten_w = ctx.Input<Tensor>("AttentionWeight"); // (M+D) x 1
|
|
|
|
|
auto* atten_b = ctx.Input<Tensor>("AttentionBias"); // 1x1
|
|
|
|
|
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar"); // 1x1
|
|
|
|
|
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias"); // 1x1
|
|
|
|
|
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight"); // (D+M) x D*4
|
|
|
|
|
auto* lstm_b = ctx.Input<Tensor>("LSTMBias"); // 1 x D*4
|
|
|
|
|
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
|
|
|
|
|
auto* atted_x = ctx.Output<Tensor>("AttentionedX"); // T x 1
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut"); // max_seq_len x 1
|
|
|
|
|
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
|
|
|
|
|
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
|
|
|
|
|
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0");
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0");
|
|
|
|
|
auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
|
|
|
|
|
auto* atten_b = ctx.Input<Tensor>("AttentionBias");
|
|
|
|
|
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
|
|
|
|
|
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
|
|
|
|
|
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
|
|
|
|
|
auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
|
|
|
|
|
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell");
|
|
|
|
|
auto* atted_x = ctx.Output<Tensor>("AttentionedX");
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
|
|
|
|
|
auto* lstm_x = ctx.Output<Tensor>("LSTMX");
|
|
|
|
|
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
|
|
|
|
|
|
|
|
|
|
// some shapes should be reshaped here since InferShape cannot get LoD info
|
|
|
|
|
auto x_lod = x->lod();
|
|
|
|
@ -310,11 +311,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
|
|
|
|
|
fc_out->Resize({max_seq_len, 1});
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): act functor init here
|
|
|
|
|
// if (platform::jit::MayIUse(platform::jit::avx2)) {
|
|
|
|
|
// } else if (platform::jit::MayIUse(platform::jit::avx)) {
|
|
|
|
|
// } else {
|
|
|
|
|
// }
|
|
|
|
|
math::VecActivations<T> act_functor;
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
|
|
|
|
|
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
|
|
|
|
|
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
|
|
|
|
|
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* h0_data = h0 ? h0->data<T>() : NULL;
|
|
|
|
@ -381,9 +382,9 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
|
|
|
|
|
|
|
|
|
|
// gate act: sigmoid
|
|
|
|
|
math::vec_sigmoid(D3, lstm_out_data, lstm_out_data);
|
|
|
|
|
act_gate(D3, lstm_out_data, lstm_out_data);
|
|
|
|
|
// candidate act: tanh
|
|
|
|
|
math::vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3);
|
|
|
|
|
act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
|
|
|
|
|
|
|
|
|
|
// a = forget * prev_cell
|
|
|
|
|
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
|
|
|
|
@ -395,7 +396,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
|
|
|
|
|
|
|
|
|
|
// state act tanh(cell_out) * output_gate
|
|
|
|
|
math::vec_tanh(D, cur_cell_out_data, lstm_out_data);
|
|
|
|
|
act_cell(D, cur_cell_out_data, lstm_out_data);
|
|
|
|
|
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
|
|
|
|
|
|
|
|
|
|
prev_hidden_data = cur_hidden_out_data;
|
|
|
|
|