lookuptable support SelectedRows as table parameter

helinwang-patch-1
qiaolongfei 8 years ago
parent 997e9a1fd2
commit 3b3d210c3e

@ -10,6 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
@ -52,7 +55,7 @@ class SelectedRows {
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simplely concated when adding together. Until a
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr};

@ -84,7 +84,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
.SetDefault(kNoPadding);
AddComment(R"DOC(
Lookup Table Operator.

@ -14,6 +14,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
@ -25,14 +28,35 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static const int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return std::distance(rows.begin(), it);
}
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto* table_t = context.Input<LoDTensor>("W");
auto *table_var = context.InputVar("W");
auto *ids_var = context.InputVar("Ids");
Tensor *output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
}
int64_t *ids;
int64_t ids_numel;
@ -49,32 +73,43 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t *>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
output_t->Resize({ids_numel, table_dim[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
int N = table_t->dims()[0];
int D = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} else {
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
auto id_index = getIndex(table_t.rows(), ids[i]);
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
}
}
}

Loading…
Cancel
Save