|
|
|
@ -66,7 +66,7 @@ class GRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
bool is_reverse = context.Attr<bool>("is_reverse");
|
|
|
|
|
math::LoDTensor2BatchFunctor<Place, T> to_batch;
|
|
|
|
|
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
|
|
|
|
|
to_batch(context.device_context(), *input, *batch_gate, is_reverse);
|
|
|
|
|
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
|
|
|
|
|
|
|
|
|
|
int frame_size = hidden_dims[1];
|
|
|
|
|
int batch_size = hidden_dims[0];
|
|
|
|
@ -172,8 +172,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
batch_hidden_grad.set_lod(batch_hidden->lod());
|
|
|
|
|
// context.ShareLoD(framework::GradVarName("Hidden"),
|
|
|
|
|
// framework::GradVarName("Input"));
|
|
|
|
|
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad,
|
|
|
|
|
is_reverse, false);
|
|
|
|
|
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
|
|
|
|
|
is_reverse);
|
|
|
|
|
|
|
|
|
|
math::hl_gru_value<T> gru_value;
|
|
|
|
|
gru_value.gateWeight = const_cast<T*>(weight_data);
|
|
|
|
|