|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/lstm_op.h"
|
|
|
|
|
#include "paddle/operators/math/gru_compute.h"
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/operators/math/sequence2batch.h"
|
|
|
|
@ -25,6 +24,18 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
inline void ReorderInitState(const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Tensor& src, const size_t* index,
|
|
|
|
|
framework::Tensor* dst, bool indexed_src) {
|
|
|
|
|
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
|
|
|
|
|
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
|
|
|
|
|
row_shuffle(ctx, src, index, *dst, indexed_src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class GRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
batch_reset_hidden_prev_grad.Slice(bstart, bend);
|
|
|
|
|
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
|
|
|
|
|
if (n == 0) {
|
|
|
|
|
if (h0) {
|
|
|
|
|
gru_value.prevOutValue = ordered_h0.data<T>();
|
|
|
|
|
} else {
|
|
|
|
|
gru_value.prevOutValue = nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (h0 && h0_grad) {
|
|
|
|
|
gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
|
|
|
|
|
} else {
|
|
|
|
|
gru_grad.prevOutGrad = nullptr;
|
|
|
|
|
}
|
|
|
|
|
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
|
|
|
|
|
gru_grad.prevOutGrad =
|
|
|
|
|
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
|
|
|
|
|
} else {
|
|
|
|
|
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
|
|
|
|
|
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
|
|
|
|
|