|
|
|
@ -15,10 +15,14 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/fusion_lstm_op.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/detail/activation_functions.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/lstm_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/sequence2batch.h"
|
|
|
|
|
#include "paddle/fluid/platform/cpu_info.h"
|
|
|
|
|
|
|
|
|
|
DEFINE_bool(seq_mode, true, "Use sequence mode");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
ctx->ShareLoD("X", "Hidden");
|
|
|
|
|
ctx->ShareLoD("X", "Cell");
|
|
|
|
|
|
|
|
|
|
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
int xx_width;
|
|
|
|
|
if (FLAGS_seq_mode) {
|
|
|
|
|
xx_width = wx_dims[1];
|
|
|
|
|
} else {
|
|
|
|
|
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
|
|
|
|
|
ctx->ShareLoD("X", "XX");
|
|
|
|
|
}
|
|
|
|
@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx,
|
|
|
|
|
row_shuffle(ctx, src, index_lod, dst, indexed_src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX");
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH");
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX");
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims(); // T x M
|
|
|
|
|
auto wh_dims = wh->dims(); // D x 4D
|
|
|
|
|
const int M = x_dims[1]; // x frame size
|
|
|
|
|
const int D4 = wh_dims[1];
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* wx_data = wx->data<T>();
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
|
|
|
|
|
xx_data, bias->data<T>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BatchCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX");
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH");
|
|
|
|
@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
// restore the output cell state in LoDTensor from the batch cell
|
|
|
|
|
to_seq(dev_ctx, batch_cell, cell_out);
|
|
|
|
|
}
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
if (FLAGS_seq_mode) {
|
|
|
|
|
SeqCompute(ctx);
|
|
|
|
|
} else {
|
|
|
|
|
BatchCompute(ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -348,7 +388,5 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
fusion_lstm,
|
|
|
|
|
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
|
|
|
|
|
ops::FuisonLSTMKernel<double>);
|
|
|
|
|