|
|
|
@ -55,12 +55,16 @@ class LookupSparseTableOp : public framework::OperatorBase {
|
|
|
|
|
"The type of Out var should be LodTensor.");
|
|
|
|
|
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
|
|
|
|
|
"The type of W var should be SelectedRows.");
|
|
|
|
|
PADDLE_ENFORCE(ids_var->IsType<framework::SelectedRows>(),
|
|
|
|
|
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
|
|
|
|
|
"The type of Ids var should be SelectedRows.");
|
|
|
|
|
auto &ids_t = ids_var->Get<framework::SelectedRows>();
|
|
|
|
|
auto &ids_t = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto out_t = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto w_t = w_var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
auto keys = ids_t.rows();
|
|
|
|
|
std::vector<int64_t> keys;
|
|
|
|
|
keys.resize(ids_t.numel());
|
|
|
|
|
for (size_t i = 0; i < ids_t.numel(); ++i) {
|
|
|
|
|
keys[i] = ids_t.data<int64_t>()[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(Yancey1989): support CUDA Place for the sparse table
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
@ -68,7 +72,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
|
|
|
|
|
out_shape[0] = keys.size();
|
|
|
|
|
out_t->Resize(out_shape);
|
|
|
|
|
out_t->mutable_data(cpu, w_t->value().type());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
|
|
|
|
|
framework::proto::VarType::FP32,
|
|
|
|
|
"The sparse table only support FP32");
|
|
|
|
|