From 53d8165f5379680396fff750184ead563d754d24 Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 1 Nov 2017 11:24:42 +0800 Subject: [PATCH] Make GRU Operator adapt to sequence2batch --- paddle/operators/gru_op.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index a04dd8d05f..2c9aa76242 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -66,7 +66,7 @@ class GRUKernel : public framework::OpKernel { bool is_reverse = context.Attr("is_reverse"); math::LoDTensor2BatchFunctor to_batch; // to_batch(context.device_context(), *input, batch_gate, is_reverse); - to_batch(context.device_context(), *input, *batch_gate, is_reverse); + to_batch(context.device_context(), *input, *batch_gate, true, is_reverse); int frame_size = hidden_dims[1]; int batch_size = hidden_dims[0]; @@ -172,8 +172,8 @@ class GRUGradKernel : public framework::OpKernel { batch_hidden_grad.set_lod(batch_hidden->lod()); // context.ShareLoD(framework::GradVarName("Hidden"), // framework::GradVarName("Input")); - to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, - is_reverse, false); + to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false, + is_reverse); math::hl_gru_value gru_value; gru_value.gateWeight = const_cast(weight_data);