|
|
|
@ -120,12 +120,22 @@ template <typename T>
|
|
|
|
|
class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
|
auto *table_var = context.InputVar("W");
|
|
|
|
|
DDim table_dim;
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
table_dim = context.Input<LoDTensor>("W")->dims();
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
auto *table_t = context.Input<SelectedRows>("W");
|
|
|
|
|
table_dim = table_t->value().dims();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("table only support LoDTensor and SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_sparse = context.Attr<bool>("is_sparse");
|
|
|
|
|
// Since paddings are not trainable and fixed in forward, the gradient of
|
|
|
|
|
// paddings makes no sense and we don't deal with it in backward.
|
|
|
|
|
if (is_sparse) {
|
|
|
|
|
auto *ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto *table = context.Input<LoDTensor>("W");
|
|
|
|
|
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
|
|
|
|
|
|
|
|
|
@ -140,10 +150,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
d_table->set_rows(new_rows);
|
|
|
|
|
|
|
|
|
|
auto *d_table_value = d_table->mutable_value();
|
|
|
|
|
d_table_value->Resize({ids_dim[0], table->dims()[1]});
|
|
|
|
|
d_table_value->Resize({ids_dim[0], table_dim[1]});
|
|
|
|
|
d_table_value->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
d_table->set_height(table->dims()[0]);
|
|
|
|
|
d_table->set_height(table_dim[0]);
|
|
|
|
|
|
|
|
|
|
auto *d_output_data = d_output->data<T>();
|
|
|
|
|
auto *d_table_data = d_table_value->data<T>();
|
|
|
|
@ -154,12 +164,11 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
|
|
|
|
|
auto *table = context.Input<LoDTensor>("W");
|
|
|
|
|
|
|
|
|
|
auto *ids_data = ids->data<int64_t>();
|
|
|
|
|
auto ids_dim = ids->dims();
|
|
|
|
|
|
|
|
|
|
int N = table->dims()[0];
|
|
|
|
|
int N = table_dim[0];
|
|
|
|
|
int D = d_output->dims()[1];
|
|
|
|
|
|
|
|
|
|
auto *d_output_data = d_output->data<T>();
|
|
|
|
|