|
|
|
@ -119,9 +119,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto *d_table_value = d_table->mutable_value();
|
|
|
|
|
d_table_value->Resize({ids_num, table_dim[1]});
|
|
|
|
|
// FIXME(minqiyang):
|
|
|
|
|
// memory optimization will NOT reuse Tensor with SelectedRows
|
|
|
|
|
// so we could just share the tensor here directly.
|
|
|
|
|
d_table_value->ShareDataWith(*d_output);
|
|
|
|
|
// However, the InferVarType method will infer the output SelectedRows
|
|
|
|
|
// to Tensor sometimes, which is a bug, so we will add an attribute
|
|
|
|
|
// here to indicate the inplace and remove this attribute after
|
|
|
|
|
// the InferVarType's bug was fixed
|
|
|
|
|
bool grad_inplace = context.Attr<bool>("grad_inplace");
|
|
|
|
|
if (grad_inplace) {
|
|
|
|
|
d_table_value->ShareDataWith(*d_output);
|
|
|
|
|
} else {
|
|
|
|
|
d_table_value->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
d_table->set_height(table_dim[0]);
|
|
|
|
|
|
|
|
|
|
auto *d_output_data = d_output->data<T>();
|
|
|
|
|
auto *d_table_data = d_table_value->data<T>();
|
|
|
|
|
|
|
|
|
|
auto d_output_dims = d_output->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
d_table_value->dims(),
|
|
|
|
|
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
|
|
|
|
|
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto *ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|