Make GRU Operator adapt to sequence2batch

mobile_baidu
guosheng 7 years ago
parent 83b48ebcb7
commit 53d8165f53

@ -66,7 +66,7 @@ class GRUKernel : public framework::OpKernel<T> {
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
to_batch(context.device_context(), *input, *batch_gate, is_reverse);
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
@ -172,8 +172,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_hidden_grad.set_lod(batch_hidden->lod());
// context.ShareLoD(framework::GradVarName("Hidden"),
// framework::GradVarName("Input"));
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad,
is_reverse, false);
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse);
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);

Loading…
Cancel
Save