|
|
|
@ -51,26 +51,16 @@ class GRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* hidden = context.Output<LoDTensor>("Hidden");
|
|
|
|
|
hidden->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
// context.ShareLoD("Input", "Gate");
|
|
|
|
|
// context.ShareLoD("Input", "ResetHiddenPrev");
|
|
|
|
|
context.ShareLoD("Input", "Hidden");
|
|
|
|
|
|
|
|
|
|
// auto gate_dims = gate->dims();
|
|
|
|
|
auto hidden_dims = hidden->dims();
|
|
|
|
|
|
|
|
|
|
// LoDTensor batch_gate, batch_reset_hidden_prev, batch_hidden;
|
|
|
|
|
// batch_gate.mutable_data<T>(gate_dims, context.GetPlace());
|
|
|
|
|
// batch_reset_hidden_prev.mutable_data<T>(hidden_dims, context.GetPlace());
|
|
|
|
|
// batch_hidden.mutable_data<T>(hidden_dims, context.GetPlace());
|
|
|
|
|
|
|
|
|
|
bool is_reverse = context.Attr<bool>("is_reverse");
|
|
|
|
|
math::LoDTensor2BatchFunctor<Place, T> to_batch;
|
|
|
|
|
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
|
|
|
|
|
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
|
|
|
|
|
|
|
|
|
|
int frame_size = hidden_dims[1];
|
|
|
|
|
int batch_size = hidden_dims[0];
|
|
|
|
|
// auto g = EigenMatrix<T>::From(batch_gate);
|
|
|
|
|
auto g = EigenMatrix<T>::From(*batch_gate);
|
|
|
|
|
auto place = context.GetEigenDevice<Place>();
|
|
|
|
|
if (bias) {
|
|
|
|
@ -85,20 +75,13 @@ class GRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
gru_value.stateWeight =
|
|
|
|
|
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
|
|
|
|
|
gru_value.prevOutValue = const_cast<T*>(h0_data);
|
|
|
|
|
// auto batch_starts = batch_gate.lod()[0];
|
|
|
|
|
auto batch_starts = batch_gate->lod()[0];
|
|
|
|
|
// for (auto i = batch_gate->lod()[1].begin(); i !=
|
|
|
|
|
// batch_gate->lod()[1].end(); ++i)
|
|
|
|
|
// std::cout << static_cast<int>(*i) << ' ';
|
|
|
|
|
size_t num_batch = batch_starts.size() - 1;
|
|
|
|
|
for (size_t n = 0; n < num_batch; n++) {
|
|
|
|
|
int bstart = static_cast<int>(batch_starts[n]);
|
|
|
|
|
int bend = static_cast<int>(batch_starts[n + 1]);
|
|
|
|
|
int cur_batch_size = bend - bstart;
|
|
|
|
|
|
|
|
|
|
// Tensor gate_t = batch_gate.Slice(bstart, bend);
|
|
|
|
|
// Tensor reset_hidden_prev_t = batch_reset_hidden_prev.Slice(bstart,
|
|
|
|
|
// bend);
|
|
|
|
|
Tensor gate_t = batch_gate->Slice(bstart, bend);
|
|
|
|
|
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
|
|
|
|
|
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
|
|
|
|
@ -113,13 +96,6 @@ class GRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
math::Batch2LoDTensorFunctor<Place, T> to_seq;
|
|
|
|
|
// batch_gate.set_lod(batch_gate.lod());
|
|
|
|
|
// to_seq(context.device_context(), batch_gate, *gate);
|
|
|
|
|
// batch_reset_hidden_prev.set_lod(batch_gate.lod());
|
|
|
|
|
// to_seq(context.device_context(), batch_reset_hidden_prev,
|
|
|
|
|
// *reset_hidden_prev);
|
|
|
|
|
// batch_hidden.set_lod(batch_gate.lod());
|
|
|
|
|
// to_seq(context.device_context(), batch_hidden, *hidden);
|
|
|
|
|
batch_hidden->set_lod(batch_gate->lod());
|
|
|
|
|
to_seq(context.device_context(), *batch_hidden, *hidden);
|
|
|
|
|
}
|
|
|
|
@ -167,11 +143,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
zero(context.device_context(), &batch_reset_hidden_prev_grad,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
|
|
|
|
|
// batch_hidden.set_lod(batch_gate->lod());
|
|
|
|
|
bool is_reverse = context.Attr<bool>("is_reverse");
|
|
|
|
|
batch_hidden_grad.set_lod(batch_hidden->lod());
|
|
|
|
|
// context.ShareLoD(framework::GradVarName("Hidden"),
|
|
|
|
|
// framework::GradVarName("Input"));
|
|
|
|
|
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
|
|
|
|
|
is_reverse);
|
|
|
|
|
|
|
|
|
|