|
|
|
@ -15,9 +15,9 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
|
|
|
|
|
#include <cstring> // for memcpy
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernels.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/sequence2batch.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -183,27 +183,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int total_T = x_dims[0]; \
|
|
|
|
|
const int D3 = wh_dims[1]
|
|
|
|
|
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const math::jitkernel::gru_attr_t attr( \
|
|
|
|
|
D, ctx.Attr<std::string>("gate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("activation")); \
|
|
|
|
|
math::jitkernel::gru_t one_step; \
|
|
|
|
|
const auto& ker = \
|
|
|
|
|
math::jitkernel::KernelPool::Instance() \
|
|
|
|
|
.template Get<math::jitkernel::GRUKernel<T>, \
|
|
|
|
|
const math::jitkernel::gru_attr_t&>(attr); \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
// INIT_OTHER_DEFINES expands to a run of local declarations for the GRU
// kernel code in this file.
//
// Names that must already be in scope where the macro is expanded (they are
// referenced here but not declared here): ctx, x_dims, wh_dims, x, wh, xx.
// NOTE(review): these presumably come from the base-defines macro that
// precedes this one -- confirm at the expansion sites.
//
// Locals introduced by the expansion:
//   h0, wx, bias     - input tensors fetched by name ("H0", "WeightX", "Bias").
//   hidden_out       - output LoDTensor "Hidden".
//   is_reverse       - value of the bool attribute "is_reverse".
//   M, D, D2         - x_dims[1], wh_dims[0], and D * 2 respectively.
//   attr             - jit::gru_attr_t built from D and the "gate_activation"
//                      and "activation" string attributes, each converted
//                      with jit::to_kerneltype.
//   one_step         - a jit::gru_t instance, declared without an initializer.
//   ComputeH1, ComputeHtPart1, ComputeHtPart2
//                    - callables returned by jit::Get for the gruh1,
//                      gruhtpart1 and gruhtpart2 kernels, all requested for
//                      platform::CPUPlace with the same attr.
//   x_data, wx_data, wh_data
//                    - const T* data pointers of x, wx and wh.
//   place            - ctx.GetPlace().
//   xx_data          - T* obtained from xx->mutable_data<T>(place).
//
// The final declaration deliberately has no trailing semicolon, so the
// expansion site has to supply it (i.e. write `INIT_OTHER_DEFINES;`).
#define INIT_OTHER_DEFINES \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const jit::gru_attr_t attr( \
|
|
|
|
|
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
|
|
|
|
|
jit::gru_t one_step; \
|
|
|
|
|
auto ComputeH1 = \
|
|
|
|
|
jit::Get<jit::gruh1, jit::GRUTuples, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeHtPart1 = \
|
|
|
|
|
jit::Get<jit::gruhtpart1, jit::GRUTuples, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeHtPart2 = \
|
|
|
|
|
jit::Get<jit::gruhtpart2, jit::GRUTuples, platform::CPUPlace>(attr); \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(place)
|
|
|
|
|
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
@ -242,7 +244,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else {
|
|
|
|
|
one_step.gates = xx_data;
|
|
|
|
|
one_step.ht = hidden_out_data;
|
|
|
|
|
ker->ComputeH1(&one_step, &attr);
|
|
|
|
|
ComputeH1(&one_step, &attr);
|
|
|
|
|
prev_hidden_data = hidden_out_data;
|
|
|
|
|
tstart = 1;
|
|
|
|
|
move_step();
|
|
|
|
@ -255,12 +257,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
one_step.gates = xx_data;
|
|
|
|
|
one_step.ht_1 = prev_hidden_data;
|
|
|
|
|
one_step.ht = hidden_out_data;
|
|
|
|
|
ker->ComputeHtPart1(&one_step, &attr);
|
|
|
|
|
ComputeHtPart1(&one_step, &attr);
|
|
|
|
|
// gemm rt * Ws
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
|
|
|
|
|
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
|
|
|
|
|
xx_data + D2, D3);
|
|
|
|
|
ker->ComputeHtPart2(&one_step, &attr);
|
|
|
|
|
ComputeHtPart2(&one_step, &attr);
|
|
|
|
|
// save prev
|
|
|
|
|
prev_hidden_data = hidden_out_data;
|
|
|
|
|
move_step();
|
|
|
|
@ -324,7 +326,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (int i = 0; i < max_bs; ++i) {
|
|
|
|
|
one_step.gates = cur_in_data;
|
|
|
|
|
one_step.ht = cur_out_data;
|
|
|
|
|
ker->ComputeH1(&one_step, &attr);
|
|
|
|
|
ComputeH1(&one_step, &attr);
|
|
|
|
|
// add offset
|
|
|
|
|
cur_in_data += D3;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
@ -352,7 +354,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
one_step.gates = cur_batched_data;
|
|
|
|
|
one_step.ht_1 = cur_prev_hidden_data;
|
|
|
|
|
one_step.ht = cur_out_data;
|
|
|
|
|
ker->ComputeHtPart1(&one_step, &attr);
|
|
|
|
|
ComputeHtPart1(&one_step, &attr);
|
|
|
|
|
|
|
|
|
|
cur_batched_data += D3;
|
|
|
|
|
cur_prev_hidden_data += D;
|
|
|
|
@ -370,7 +372,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
one_step.gates = cur_batched_data;
|
|
|
|
|
one_step.ht_1 = cur_prev_hidden_data;
|
|
|
|
|
one_step.ht = cur_out_data;
|
|
|
|
|
ker->ComputeHtPart2(&one_step, &attr);
|
|
|
|
|
ComputeHtPart2(&one_step, &attr);
|
|
|
|
|
cur_batched_data += D3;
|
|
|
|
|
cur_prev_hidden_data += D;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
|