|
|
|
@ -15,12 +15,9 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/attention_lstm_op.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/detail/activation_functions.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/lstm_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/sequence2batch.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
// #include "paddle/fluid/operators/math/detail/activation_functions.h"
|
|
|
|
|
// #include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -233,6 +230,13 @@ use lstm_x_t as input and compute as standard LSTM.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void vec_relu(const int n, const T* x, T* y) {
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = x[i] > 0 ? x[i] : 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
|
|
|
|
@ -240,14 +244,14 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = x[i] + bias[0];
|
|
|
|
|
}
|
|
|
|
|
vec_relu(n, y, y);
|
|
|
|
|
vec_relu<T>(n, y, y);
|
|
|
|
|
} else {
|
|
|
|
|
vec_relu(n, x, y);
|
|
|
|
|
vec_relu<T>(n, x, y);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
|
|
|
|
|
inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
|
|
|
|
|
const T* x, T* y) {
|
|
|
|
|
T scalar = x[0];
|
|
|
|
|
// max
|
|
|
|
@ -257,7 +261,7 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
|
|
|
|
|
|
|
|
|
|
// sub
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[c] = x[c] - alpha;
|
|
|
|
|
y[i] = x[i] - scalar;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// exp
|
|
|
|
@ -270,57 +274,45 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// scale
|
|
|
|
|
blas.VSCAL(n, static_cast<T>(1) / scalar, y);
|
|
|
|
|
blas.SCAL(n, static_cast<T>(1) / scalar, y);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__m256 exp(__m256 a) { return exp256_ps(a); }
|
|
|
|
|
#define SIGMOID_THRESHOLD_MIN -40.0
|
|
|
|
|
#define SIGMOID_THRESHOLD_MAX 13.0
|
|
|
|
|
#define EXP_MAX_INPUT 40.0
|
|
|
|
|
|
|
|
|
|
__m256 log(__m256 a) { return log256_ps(a); }
|
|
|
|
|
|
|
|
|
|
__m256 sin(__m256 a) { return sin256_ps(a); }
|
|
|
|
|
|
|
|
|
|
__m256 cos(__m256 a) { return cos256_ps(a); }
|
|
|
|
|
|
|
|
|
|
__m256 relu(const __m256 a) {
|
|
|
|
|
__m256 tmp = _mm256_set1_ps(0.0f);
|
|
|
|
|
return _mm256_max_ps(a, tmp);
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline T sigmoid(T x) {
|
|
|
|
|
return 1. / (1. + exp(-x));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__m256 sigmoid(const __m256 a) {
|
|
|
|
|
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
|
|
|
|
|
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
|
|
|
|
|
__m256 tmp = _mm256_max_ps(a, min);
|
|
|
|
|
tmp = _mm256_min_ps(tmp, max);
|
|
|
|
|
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
|
|
|
|
|
tmp = exp(tmp);
|
|
|
|
|
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
|
|
|
|
|
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
|
|
|
|
|
return tmp;
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline T tanh(T x) {
|
|
|
|
|
return 2. * sigmoid(2. * x) - 1.;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__m256 tanh(const __m256 a) {
|
|
|
|
|
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
|
|
|
|
|
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
|
|
|
|
|
tmp = _mm256_min_ps(tmp, max);
|
|
|
|
|
tmp = exp(tmp);
|
|
|
|
|
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
|
|
|
|
|
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
|
|
|
|
|
_mm256_set1_ps(1.0f));
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void vec_sigmoid(const int n, const T* x, T* y) {
|
|
|
|
|
const T min = SIGMOID_THRESHOLD_MIN;
|
|
|
|
|
const T max = SIGMOID_THRESHOLD_MAX;
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
|
|
|
|
|
y[i] = 1.0 / (1.0 + std::exp(-tmp));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__m256 linear(const __m256 a) { return a; }
|
|
|
|
|
|
|
|
|
|
inline void vec_sigmoid(const T* x, T* y) {
|
|
|
|
|
const real min = SIGMOID_THRESHOLD_MIN;
|
|
|
|
|
const real max = SIGMOID_THRESHOLD_MAX;
|
|
|
|
|
real tmp = (a < min) ? min : ((a > max) ? max : a);
|
|
|
|
|
return 1.0 / (1.0 + exp(-tmp));
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void vec_tanh(const int n, const T* x, T* y) {
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = tanh<T>(x[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X"); // T x M
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); // N x D
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0"); // N x D
|
|
|
|
@ -334,7 +326,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
|
|
|
|
|
auto* atted_x = ctx.Output<Tensor>("AttentionedX"); // T x 1
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>('AttentionFCOut'); // max_seq_len x 1
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut"); // max_seq_len x 1
|
|
|
|
|
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
|
|
|
|
|
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
|
|
|
|
|
|
|
|
|
@ -342,9 +334,10 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto x_lod = x->lod();
|
|
|
|
|
const int N = x_lod[0].size() - 1; // batch size
|
|
|
|
|
auto x_dims = x->dims(); // T x M
|
|
|
|
|
auto w_dims = w->dims(); // (D+M) x 4D
|
|
|
|
|
const int M = x_dims[1]; // x frame size
|
|
|
|
|
const int D = w_dims[1] / 4; // gate frame size
|
|
|
|
|
auto w_dims = lstm_w->dims(); // (D+M) x 4D
|
|
|
|
|
const int total_T = x_dims[0];
|
|
|
|
|
const int M = x_dims[1]; // x frame size
|
|
|
|
|
const int D = w_dims[1] / 4; // gate frame size
|
|
|
|
|
const int D2 = D * 2;
|
|
|
|
|
const int D3 = D * 3;
|
|
|
|
|
const int D4 = w_dims[1];
|
|
|
|
@ -357,6 +350,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
|
|
|
|
|
fc_out->Resize({max_seq_len, 1});
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): act functor init here
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* h0_data = h0->data<T>();
|
|
|
|
|
const T* c0_data = c0->data<T>();
|
|
|
|
@ -368,16 +363,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
const T* atten_scalar_bias_data =
|
|
|
|
|
atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
|
|
|
|
|
|
|
|
|
|
T* hidden_out_data = hidden_out->mutable_data<T>();
|
|
|
|
|
T* cell_out_data = cell_out->mutable_data<T>();
|
|
|
|
|
T* atted_x_data = atted_x->mutable_data<T>();
|
|
|
|
|
T* fc_out_data = fc_out->mutable_data<T>();
|
|
|
|
|
T* lstm_x_data = lstm_x->mutable_data<T>();
|
|
|
|
|
T* lstm_out_data = lstm_out->mutable_data<T>();
|
|
|
|
|
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
math::FCCompute<DeviceContext, T>(blas, T, 1, M, x_data, atten_w_data,
|
|
|
|
|
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
|
|
|
|
|
atted_x_data, atten_b_data);
|
|
|
|
|
|
|
|
|
|
const T* cur_x_data = x_data;
|
|
|
|
@ -400,7 +395,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
// fc2: scalar
|
|
|
|
|
if (atten_scalar_data) {
|
|
|
|
|
// x = a*x
|
|
|
|
|
blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
|
|
|
|
|
blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
|
|
|
|
|
bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
|
|
|
|
|
fc_out_data);
|
|
|
|
|
}
|
|
|
|
@ -431,16 +426,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
|
|
|
|
|
|
|
|
|
|
// b = input * tilde
|
|
|
|
|
blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D);
|
|
|
|
|
blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
|
|
|
|
|
|
|
|
|
|
// cell_out = a + b
|
|
|
|
|
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
|
|
|
|
|
|
|
|
|
|
// state act tanh(cell_out) * output_gate
|
|
|
|
|
vec_tanh(D, cur_cell_out_data, lstm_out_data);
|
|
|
|
|
blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data);
|
|
|
|
|
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
|
|
|
|
|
|
|
|
|
|
prev_hidden_data = hidden_out + i * gate_size;
|
|
|
|
|
prev_hidden_data = cur_hidden_out_data;
|
|
|
|
|
prev_cell_data = cur_cell_out_data;
|
|
|
|
|
cur_cell_out_data = cur_cell_out_data + D;
|
|
|
|
|
cur_hidden_out_data = cur_hidden_out_data + D;
|
|
|
|
@ -458,7 +453,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
|
|
|
|
|
ops::AttentionLSTMOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
attention_lstm,
|
|
|
|
|
ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
|
|
|
|
|
ops::AttentionLSTMKernel<double>);
|
|
|
|
|