|
|
|
@ -19,22 +19,22 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using SelectedRows = framework::SelectedRows;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto table_t = context.Input<Tensor>("W"); // float tensor
|
|
|
|
|
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
|
|
|
|
|
auto output_t = context.Output<Tensor>("Out"); // float tensor
|
|
|
|
|
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
|
|
|
|
|
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
|
|
|
|
|
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
|
|
|
|
|
|
|
|
|
|
int N = table_t->dims()[0];
|
|
|
|
|
int D = table_t->dims()[1];
|
|
|
|
|
auto ids = ids_t->data<int64_t>();
|
|
|
|
|
auto table = table_t->data<T>();
|
|
|
|
|
auto output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto* ids = ids_t->data<int64_t>();
|
|
|
|
|
auto* table = table_t->data<T>();
|
|
|
|
|
auto* output = output_t->mutable_data<T>(context.GetPlace());
|
|
|
|
|
for (int64_t i = 0; i < ids_t->numel(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(ids[i], N);
|
|
|
|
|
PADDLE_ENFORCE_GE(ids[i], 0);
|
|
|
|
@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
bool is_sparse = context.Attr<bool>("is_sparse");
|
|
|
|
|
if (is_sparse) {
|
|
|
|
|
auto* ids = context.Input<Tensor>("Ids");
|
|
|
|
|
auto* table = context.Input<Tensor>("W");
|
|
|
|
|
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto* table = context.Input<LoDTensor>("W");
|
|
|
|
|
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
|
|
|
|
|
|
|
|
|
|
auto* ids_data = ids->data<int64_t>();
|
|
|
|
@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
|
|
|
|
|
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
|
|
|
|
|
} else {
|
|
|
|
|
auto* ids = context.Input<Tensor>("Ids");
|
|
|
|
|
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* d_table = context.Output<Tensor>(framework::GradVarName("W"));
|
|
|
|
|
auto* table = context.Input<Tensor>("W");
|
|
|
|
|
auto* ids = context.Input<LoDTensor>("Ids");
|
|
|
|
|
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
|
|
|
|
|
auto* table = context.Input<LoDTensor>("W");
|
|
|
|
|
|
|
|
|
|
auto* ids_data = ids->data<int64_t>();
|
|
|
|
|
auto ids_dim = ids->dims();
|
|
|
|
|