Enhance lookup_table_op to support padding_idx

fix-profile-doc-typo
guosheng 8 years ago
parent f086ebb8b9
commit 9247aee7e4

@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update")
.SetDefault(false);
AddAttr<int64_t>(
"padding_idx",
"(int64_t, default -1) "
" If given, pads the output with zeros whenever it encounters "
"the index.")
.SetDefault(-1);
AddComment(R"DOC(
Lookup Table Operator.

@ -32,6 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0];
int D = table_t->dims()[1];
@ -39,11 +40,15 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
}
}
};
template <typename T>
@ -51,6 +56,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
@ -63,6 +70,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework::Vector<int64_t> new_rows;
new_rows.reserve(ids_dim[0]);
for (int64_t i = 0; i < ids_dim[0]; i++) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows);
@ -96,6 +106,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset(d_table_data, 0, d_table->numel() * sizeof(T));
for (int64_t i = 0; i < ids->numel(); ++i) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0);
for (int j = 0; j < D; ++j) {

Loading…
Cancel
Save