@@ -16,14 +16,10 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/lstm_compute.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
-DEFINE_bool(seq_mode, true, "Use sequence mode");
-
 namespace paddle {
 namespace operators {
@@ -110,7 +106,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Cell");
 
   int xx_width;
-  if (FLAGS_seq_mode) {
+  if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
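A note on the width chosen above: if wx is the x-to-hidden weight with shape [M, 4D] (the usual LSTM layout, so wx_dims[1] is the full gate width) and x_dims[1] is the input width M, then the sequence path keeps the full 4D projection in XX, while the non-seq path only sizes XX to the smaller of the two widths. A minimal arithmetic sketch with hypothetical sizes:

// Hypothetical sizes only (M = 32, D = 64); mirrors the xx_width selection above.
#include <algorithm>
#include <iostream>

int main() {
  const int M = 32;            // input feature width, x_dims[1]
  const int D = 64;            // hidden size
  const int wx_width = 4 * D;  // wx_dims[1]: one column block per LSTM gate

  const int seq_xx_width = wx_width;                 // use_seq == true
  const int batch_xx_width = std::min(M, wx_width);  // use_seq == false
  std::cout << seq_xx_width << " " << batch_xx_width << "\n";  // prints "256 32"
  return 0;
}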
@@ -189,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
                 "(bool, default: False) "
                 "whether to compute reversed LSTM.")
       .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, default: True) "
+                "whether to use seq mode to compute.")
+      .SetDefault(true);
   AddAttr<std::string>("gate_activation",
                        "(string, default: sigmoid)"
                        "The activation for input gate, forget gate and output "
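The added use_seq attribute, with its default of true, takes over from the process-wide --seq_mode gflag that the first hunk deletes. A rough sketch of the resulting behaviour, using a made-up FakeOpDesc rather than Paddle's framework classes: programs that never set the attribute keep the sequence path, and each operator instance can now choose its own mode instead of sharing one global switch.

// Hypothetical stand-ins for an op description; not PaddlePaddle's API.
#include <iostream>
#include <map>
#include <string>

struct FakeOpDesc {
  std::map<std::string, bool> bool_attrs;
  // Returns the stored value, or the registered default when unset.
  bool Attr(const std::string& name, bool default_value) const {
    auto it = bool_attrs.find(name);
    return it == bool_attrs.end() ? default_value : it->second;
  }
};

void Compute(const FakeOpDesc& op) {
  if (op.Attr("use_seq", /*default_value=*/true)) {
    std::cout << "SeqCompute\n";
  } else {
    std::cout << "BatchCompute\n";
  }
}

int main() {
  FakeOpDesc defaulted;                      // never sets use_seq
  FakeOpDesc batched{{{"use_seq", false}}};  // opts into batch mode
  Compute(defaulted);  // SeqCompute, from the default of true
  Compute(batched);    // BatchCompute
  return 0;
}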
@@ -264,8 +264,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     const int N = x_lod[0].size() - 1;  // batch size
 
     const T* x_data = x->data<T>();
-    const T* h0_data = h0 ? h0->data<T>() : NULL;
-    const T* c0_data = c0 ? c0->data<T>() : NULL;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
+    const T* c0_data = c0 ? c0->data<T>() : nullptr;
     const T* wx_data = wx->data<T>();
     const T* wh_data = wh->data<T>();
     T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
@@ -295,8 +295,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     for (int i = 0; i < N; ++i) {
       int bid = is_reverse ? N - 1 - i : i;
       int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
-      const T* prev_c_data = NULL;
-      const T* prev_h_data = NULL;
+      const T* prev_c_data = nullptr;
+      const T* prev_h_data = nullptr;
       int tstart = 0;
       if (h0_data) {
         prev_h_data = h0_data + bid * D;
@@ -351,8 +351,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = platform::CPUDeviceContext;
     INIT_BASE_INPUT_OUTPUT
-    if (x->lod()[0].size() == 2) {  // batch size == 1
+    if (x->lod()[0].size() == 2) {
       SeqCompute(ctx);
+      return;
     }
     INIT_BASE_SIZES
     INIT_VEC_FUNC
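The new early return matters because a level-0 LoD holds N + 1 cumulative offsets (the kernel's own N = x_lod[0].size() - 1 above relies on this), so a size of 2 means exactly one sequence; without the return, execution would fall through into the batched code below after SeqCompute had already written the outputs. A small standalone sketch of the offset arithmetic with made-up values:

// Made-up LoD values; shows why lod[0].size() == 2 implies batch size 1.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::vector<std::size_t> lod0 = {0, 6};   // offsets of a single sequence
  const std::size_t batch_size = lod0.size() - 1; // 1
  const std::size_t seq_len = lod0[1] - lod0[0];  // 6, as in SeqCompute's loop
  std::cout << batch_size << " " << seq_len << "\n";
  return 0;
}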
@@ -396,8 +397,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     reordered_c0->Resize({max_bs, D});
 
     int tstart = 0;
-    T* prev_h_data = NULL;
-    T* prev_c_data = NULL;
+    T* prev_h_data = nullptr;
+    T* prev_c_data = nullptr;
     if (h0) {
       // reorder h0, c0
       T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
@@ -489,7 +490,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   }
 
   void Compute(const framework::ExecutionContext& ctx) const override {
-    if (FLAGS_seq_mode) {
+    if (ctx.Attr<bool>("use_seq")) {
       SeqCompute(ctx);
     } else {
       BatchCompute(ctx);