add runtime shape for fuse_emb_seq_pool_grad

test=develop
revert-16144-rnn_mem_opt
luotao1 6 years ago
parent 0024d3f4f0
commit 5d20954ac4

@ -121,6 +121,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
// runtime shape
d_table->set_height(table_dim[0]);
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();

Loading…
Cancel
Save