@@ -24,8 +24,17 @@
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+inline void ReorderInitState(const platform::DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
 
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
@@ -33,7 +42,6 @@ class GRUKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> {
     gru_value.gateWeight = const_cast<T*>(weight_data);
     gru_value.stateWeight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    gru_value.prevOutValue = const_cast<T*>(h0_data);
+    Tensor ordered_h0;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length, the initial hidden state also needs
+      // to be reordered.
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+      gru_value.prevOutValue = ordered_h0.data<T>();
+    } else {
+      gru_value.prevOutValue = nullptr;
+    }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
     for (size_t n = 0; n < num_batch; n++) {
@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* batch_gate = context.Input<LoDTensor>("BatchGate");
@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> {
     zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
     zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
 
+    Tensor ordered_h0, ordered_h0_grad;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+    }
+    if (h0_grad) {
+      ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
+      zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
+    }
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        gru_value.prevOutValue = const_cast<T*>(h0_data);
-        if (h0_grad) {
-          T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
-          zero(dev_ctx, h0_grad, static_cast<T>(0.0));
-          gru_grad.prevOutGrad = h0_grad_data;
-        } else {
-          gru_grad.prevOutGrad = nullptr;
-        }
+        gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
+        gru_grad.prevOutGrad =
+            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
       } else {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
       math::ColwiseSum<Place, T> col_sum;
       col_sum(dev_ctx, batch_gate_grad, bias_grad);
     }
+    if (h0 && h0_grad) {
+      ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
+                                 order, h0_grad, false);
+    }
   }
 
   void Compute(const framework::ExecutionContext& context) const override {
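For readers unfamiliar with the batch reordering above: the new ReorderInitState helper delegates to math::CopyMatrixRowsFunctor, which (as used here) copies rows of src into dst by the batch-order index — a gather when indexed_src is true, a scatter back when it is false. The forward kernel uses the gather form to build ordered_h0 from H0, and the gradient kernel uses the scatter form at the end of BatchCompute to write ordered_h0_grad back into H0@GRAD. The standalone sketch below illustrates that row gather/scatter idea on plain std::vector matrices; ReorderRows is a hypothetical stand-in, not Paddle's actual functor.

#include <cstddef>
#include <iostream>
#include <vector>

// Reorder the rows of a row-major (rows x cols) matrix by an index array.
// indexed_src == true : dst[i]        = src[index[i]]  (gather; forward pass)
// indexed_src == false: dst[index[i]] = src[i]         (scatter; backward pass)
void ReorderRows(const std::vector<float>& src,
                 const std::vector<std::size_t>& index, std::vector<float>* dst,
                 std::size_t cols, bool indexed_src) {
  dst->assign(src.size(), 0.0f);
  for (std::size_t i = 0; i < index.size(); ++i) {
    const std::size_t src_row = indexed_src ? index[i] : i;
    const std::size_t dst_row = indexed_src ? i : index[i];
    for (std::size_t c = 0; c < cols; ++c) {
      (*dst)[dst_row * cols + c] = src[src_row * cols + c];
    }
  }
}

int main() {
  // Initial hidden states for three sequences (rows 0, 1, 2 of a 3 x 2 matrix).
  // Suppose sorting the sequences by length yields the batch order {2, 0, 1}.
  const std::vector<float> h0 = {0.f, 0.f, 1.f, 1.f, 2.f, 2.f};
  const std::vector<std::size_t> order = {2, 0, 1};

  std::vector<float> ordered_h0;
  ReorderRows(h0, order, &ordered_h0, /*cols=*/2, /*indexed_src=*/true);

  for (float v : ordered_h0) std::cout << v << ' ';  // prints: 2 2 0 0 1 1
  std::cout << '\n';
  return 0;
}

Running ReorderRows again on the result with indexed_src set to false and the same order inverts the permutation, which is why the gradient kernel can hand the reordered gradient straight back to h0_grad in the last hunk.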