|
|
|
@ -16,10 +16,9 @@ limitations under the License. */
|
|
|
|
|
#include <cstring> // for memcpy
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/sequence2batch.h"
|
|
|
|
|
#include "paddle/fluid/platform/cpu_info.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// INIT_VEC_FUNC: declares three std::function locals in the expanding scope:
//   act_gate, act_state : (n, in, out) element-wise activation callables
//   cross               : (n, a, b, c, out) callable bound to math::vec_cross
// The activation names are read from the "gate_activation" and "activation"
// attributes of `ctx`. When platform::jit::MayIUse(platform::jit::avx) is
// true the avx specializations of VecActivations / vec_cross are bound,
// otherwise the isa_any ones.
// Requires `ctx` and the template parameter `T` to be visible at the point of
// expansion. The expansion ends with a closing brace, so it is written
// without a trailing ';' at the call sites.
#define INIT_VEC_FUNC \
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_state; \
|
|
|
|
|
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
|
|
|
|
|
auto& act_state_str = ctx.Attr<std::string>("activation"); \
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) { \
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor; \
|
|
|
|
|
act_gate = act_functor(act_gate_str); \
|
|
|
|
|
act_state = act_functor(act_state_str); \
|
|
|
|
|
cross = math::vec_cross<T, platform::jit::avx>; \
|
|
|
|
|
} else { \
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor; \
|
|
|
|
|
act_gate = act_functor(act_gate_str); \
|
|
|
|
|
act_state = act_functor(act_state_str); \
|
|
|
|
|
cross = math::vec_cross<T, platform::jit::isa_any>; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// INIT_BASE_INPUT_OUTPUT: fetches the operator's tensors from `ctx` into
// local pointers: inputs "H0", "WeightX", "WeightH", "Bias" (as Tensor),
// outputs "XX" and "Hidden" (as LoDTensor), plus the bool attribute
// "is_reverse".
// The "X" input is NOT fetched here; the visible call sites declare `x`
// themselves just before expanding this macro.
// The last declaration carries its own ';', so the macro is expanded without
// a trailing ';'.
#define INIT_BASE_INPUT_OUTPUT \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse");
|
|
|
|
|
|
|
|
|
|
// INIT_BASE_SIZES: derives the problem sizes from the dims of `x` and `wh`,
// both of which must already be declared in the expanding scope:
//   total_T = x_dims[0], M = x_dims[1]     (x is annotated as T x M)
//   D = wh_dims[0], D3 = wh_dims[1]        (wh is annotated as D x 3D)
//   D2 = D * 2
// The dims are narrowed into `const int`.
// NOTE(review): x_dims, wh_dims, total_T and D3 are also declared by
// INIT_BASE_DEFINES, and M, D, D2 by INIT_OTHER_DEFINES; expanding them in
// the same scope as this macro would redeclare those names -- confirm only
// one set is used per function.
#define INIT_BASE_SIZES \
|
|
|
|
|
auto x_dims = x->dims(); /* T x M*/ \
|
|
|
|
|
auto wh_dims = wh->dims(); /* D x 3D*/ \
|
|
|
|
|
const int total_T = x_dims[0]; \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D3 = wh_dims[1]; \
|
|
|
|
|
const int D2 = D * 2;
|
|
|
|
|
// INIT_BASE_DEFINES: declares the minimal set of locals taken from `ctx`:
//   x  ("X" input, LoDTensor), wh ("WeightH" input, Tensor),
//   xx ("XX" output, LoDTensor),
//   x_lod (by-value result of x->lod()), x_dims, wh_dims,
//   total_T = x_dims[0], D3 = wh_dims[1].
// The final declaration has no ';' -- the expansion site supplies it
// (written as `INIT_BASE_DEFINES;`).
// INIT_OTHER_DEFINES relies on x, wh, xx, x_dims and wh_dims declared here.
#define INIT_BASE_DEFINES \
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X"); \
|
|
|
|
|
auto* wh = ctx.Input<Tensor>("WeightH"); \
|
|
|
|
|
auto* xx = ctx.Output<LoDTensor>("XX"); \
|
|
|
|
|
auto x_lod = x->lod(); \
|
|
|
|
|
auto x_dims = x->dims(); /* T x M*/ \
|
|
|
|
|
auto wh_dims = wh->dims(); /* D x 3D*/ \
|
|
|
|
|
const int total_T = x_dims[0]; \
|
|
|
|
|
const int D3 = wh_dims[1]
|
|
|
|
|
|
|
|
|
|
// INIT_OTHER_DEFINES: declares the remaining locals needed for the GRU
// computation. It uses x, wh, xx, x_dims and wh_dims, so it must be expanded
// after INIT_BASE_DEFINES (or after equivalent declarations).
// Declares:
//   h0, wx, bias      : "H0", "WeightX", "Bias" inputs (Tensor)
//   hidden_out        : "Hidden" output (LoDTensor)
//   is_reverse        : "is_reverse" attribute
//   M, D, D2          : x_dims[1], wh_dims[0], D * 2
//   ker               : GRUKernel<T> obtained from
//                       math::jitkernel::KernelPool::Instance(), looked up
//                       with the "gate_activation" and "activation"
//                       attribute strings and D
//   x_data, wx_data, wh_data : read-only data pointers of x, wx, wh
//   place             : ctx.GetPlace()
//   xx_data           : xx->mutable_data<T>(place)
// The final declaration has no ';' -- the expansion site supplies it
// (written as `INIT_OTHER_DEFINES;`).
// NOTE(review): the visible call sites use `ker->...`, so Get<> presumably
// returns a pointer-like handle -- confirm against jit_kernel.h.
#define INIT_OTHER_DEFINES \
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0"); \
|
|
|
|
|
auto* wx = ctx.Input<Tensor>("WeightX"); \
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias"); \
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
|
|
|
|
|
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
|
|
|
|
|
const int M = x_dims[1]; \
|
|
|
|
|
const int D = wh_dims[0]; \
|
|
|
|
|
const int D2 = D * 2; \
|
|
|
|
|
const auto& ker = math::jitkernel::KernelPool::Instance() \
|
|
|
|
|
.template Get<math::jitkernel::GRUKernel<T>, \
|
|
|
|
|
const std::string&, const std::string&>( \
|
|
|
|
|
ctx.Attr<std::string>("gate_activation"), \
|
|
|
|
|
ctx.Attr<std::string>("activation"), D); \
|
|
|
|
|
const T* x_data = x->data<T>(); \
|
|
|
|
|
const T* wx_data = wx->data<T>(); \
|
|
|
|
|
const T* wh_data = wh->data<T>(); \
|
|
|
|
|
auto place = ctx.GetPlace(); \
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(place)
|
|
|
|
|
|
|
|
|
|
void SeqCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
INIT_BASE_SIZES
|
|
|
|
|
INIT_VEC_FUNC
|
|
|
|
|
|
|
|
|
|
auto x_lod = x->lod();
|
|
|
|
|
INIT_BASE_DEFINES;
|
|
|
|
|
INIT_OTHER_DEFINES;
|
|
|
|
|
const int N = x_lod[0].size() - 1;
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* h0_data = h0 ? h0->data<T>() : nullptr;
|
|
|
|
|
const T* wx_data = wx->data<T>();
|
|
|
|
|
const T* wh_data = wh->data<T>();
|
|
|
|
|
const T* wh_state_data = wh_data + D * D2;
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
T* hidden_out_data = hidden_out->mutable_data<T>(place);
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
|
|
|
|
|
xx_data,
|
|
|
|
@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (h0_data) {
|
|
|
|
|
prev_hidden_data = h0_data + bid * D;
|
|
|
|
|
} else {
|
|
|
|
|
// W: {W_update, W_reset; W_state}
|
|
|
|
|
// update gate
|
|
|
|
|
act_gate(D, xx_data, xx_data);
|
|
|
|
|
// state gate
|
|
|
|
|
act_state(D, xx_data + D2, xx_data + D2);
|
|
|
|
|
// out = a*b
|
|
|
|
|
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
|
|
|
|
|
// save prev
|
|
|
|
|
ker->ComputeH1(xx_data, hidden_out_data);
|
|
|
|
|
prev_hidden_data = hidden_out_data;
|
|
|
|
|
tstart = 1;
|
|
|
|
|
move_step();
|
|
|
|
@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
|
|
|
|
|
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
|
|
|
|
|
D3);
|
|
|
|
|
act_gate(D2, xx_data, xx_data);
|
|
|
|
|
// rt = rt*ht_1 inplace result
|
|
|
|
|
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
|
|
|
|
|
|
|
|
|
|
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
|
|
|
|
|
// gemm rt * Ws
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
|
|
|
|
|
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
|
|
|
|
|
xx_data + D2, D3);
|
|
|
|
|
act_state(D, xx_data + D2, xx_data + D2);
|
|
|
|
|
// out = zt*ht~ + (1-zt)*ht_1
|
|
|
|
|
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
|
|
|
|
|
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
|
|
|
|
|
// save prev
|
|
|
|
|
prev_hidden_data = hidden_out_data;
|
|
|
|
|
move_step();
|
|
|
|
@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
void BatchCompute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
using DeviceContext = paddle::platform::CPUDeviceContext;
|
|
|
|
|
auto* x = ctx.Input<LoDTensor>("X");
|
|
|
|
|
INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
INIT_BASE_SIZES
|
|
|
|
|
if (x->lod()[0].size() == 2) {
|
|
|
|
|
INIT_BASE_DEFINES;
|
|
|
|
|
if (x_lod[0].size() == 2) {
|
|
|
|
|
xx->Resize({total_T, D3});
|
|
|
|
|
SeqCompute(ctx);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
INIT_VEC_FUNC
|
|
|
|
|
|
|
|
|
|
INIT_OTHER_DEFINES;
|
|
|
|
|
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
|
|
|
|
|
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
|
|
|
|
|
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* wx_data = wx->data<T>();
|
|
|
|
|
const T* wh_data = wh->data<T>();
|
|
|
|
|
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
hidden_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
T* batched_input_data = batched_input->mutable_data<T>(place);
|
|
|
|
|
T* batched_out_data = batched_out->mutable_data<T>(place);
|
|
|
|
|
hidden_out->mutable_data<T>(place);
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
|
|
|
|
|
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
|
|
|
|
@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* prev_hidden_data = nullptr;
|
|
|
|
|
if (h0) {
|
|
|
|
|
// reorder h0
|
|
|
|
|
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
|
|
|
|
|
const T* h0_data = h0->data<T>();
|
|
|
|
|
prev_hidden_data = reordered_h0_data;
|
|
|
|
|
size_t sz = sizeof(T) * D;
|
|
|
|
@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* cur_out_data = batched_out_data;
|
|
|
|
|
// W: {W_update, W_reset; W_state}
|
|
|
|
|
for (int i = 0; i < max_bs; ++i) {
|
|
|
|
|
// update gate
|
|
|
|
|
act_gate(D, cur_in_data, cur_in_data);
|
|
|
|
|
// state gate
|
|
|
|
|
act_state(D, cur_in_data + D2, cur_in_data + D2);
|
|
|
|
|
// out = a*b
|
|
|
|
|
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
|
|
|
|
|
ker->ComputeH1(cur_in_data, cur_out_data);
|
|
|
|
|
// add offset
|
|
|
|
|
cur_in_data += D3;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* cur_out_data = batched_out_data;
|
|
|
|
|
T* cur_prev_hidden_data = prev_hidden_data;
|
|
|
|
|
for (int i = 0; i < cur_bs; ++i) {
|
|
|
|
|
act_gate(D2, cur_batched_data, cur_batched_data);
|
|
|
|
|
// rt = rt*ht_1 inplace result
|
|
|
|
|
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
|
|
|
|
|
|
|
|
|
|
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
|
|
|
|
|
cur_out_data);
|
|
|
|
|
cur_batched_data += D3;
|
|
|
|
|
cur_prev_hidden_data += D;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
cur_prev_hidden_data = prev_hidden_data;
|
|
|
|
|
for (int i = 0; i < cur_bs; ++i) {
|
|
|
|
|
// ht~ = act_state(...)
|
|
|
|
|
act_state(D, cur_batched_data + D2, cur_batched_data + D2);
|
|
|
|
|
// out = zt*ht~ + (1-zt)*ht_1
|
|
|
|
|
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
|
|
|
|
|
cur_out_data);
|
|
|
|
|
|
|
|
|
|
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
|
|
|
|
|
cur_out_data);
|
|
|
|
|
cur_batched_data += D3;
|
|
|
|
|
cur_prev_hidden_data += D;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
batched_out->set_lod(batched_lod);
|
|
|
|
|
to_seq(dev_ctx, *batched_out, hidden_out);
|
|
|
|
|
}
|
|
|
|
|
#undef INIT_VEC_FUNC
|
|
|
|
|
#undef INIT_BASE_SIZES
|
|
|
|
|
#undef INIT_BASE_INPUT_OUTPUT
|
|
|
|
|
#undef INIT_OTHER_DEFINES
|
|
|
|
|
#undef INIT_BASE_DEFINES
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|