|
|
|
@ -182,29 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int total_T = x_dims[0]; \
|
|
|
|
|
const int D3 = wh_dims[1]
|
|
|
|
|
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const jit::gru_attr_t attr( \
|
|
|
|
|
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
|
|
|
|
|
jit::gru_t one_step; \
|
|
|
|
|
auto ComputeH1 = \
|
|
|
|
|
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeHtPart1 = \
|
|
|
|
|
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
|
|
|
|
|
auto ComputeHtPart2 = \
|
|
|
|
|
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
#define INIT_OTHER_DEFINES \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const jit::gru_attr_t attr( \
|
|
|
|
|
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
|
|
|
|
|
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
|
|
|
|
|
jit::gru_t one_step; \
|
|
|
|
|
auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \
|
|
|
|
|
platform::CPUPlace>::Cache() \
|
|
|
|
|
.At(attr); \
|
|
|
|
|
auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \
|
|
|
|
|
platform::CPUPlace>::Cache() \
|
|
|
|
|
.At(attr); \
|
|
|
|
|
auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \
|
|
|
|
|
platform::CPUPlace>::Cache() \
|
|
|
|
|
.At(attr); \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(place)
|
|
|
|
|
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|