|
|
|
@ -121,6 +121,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
|
|
|
|
|
// runtime shape
|
|
|
|
|
d_table->set_height(table_dim[0]);
|
|
|
|
|
|
|
|
|
|
auto *ids_data = ids->data<int64_t>();
|
|
|
|
|
int64_t ids_num = ids->numel();
|
|
|
|
|