|
|
|
@ -43,8 +43,6 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
for (auto& ids_dim : ids_dims) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
|
|
|
|
|
"The dimension of the 'Ids' tensor must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dim[1], 1,
|
|
|
|
|
"The last dimension of the 'Ids' tensor must be 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto lookup_tables =
|
|
|
|
@ -52,6 +50,8 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto height_sections =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
|
|
|
|
|
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
|
|
|
|
|
auto lookup_table_version =
|
|
|
|
|
ctx->Attrs().Get<std::string>("lookup_table_version");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
|
|
|
|
|
lookup_tables.size() == endpoints.size() &&
|
|
|
|
@ -62,7 +62,14 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto outputs_dims = std::vector<framework::DDim>();
|
|
|
|
|
|
|
|
|
|
for (auto& ids_dim : ids_dims) {
|
|
|
|
|
outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
|
|
|
|
|
if (lookup_table_version == "lookup_table") {
|
|
|
|
|
outputs_dims.push_back(
|
|
|
|
|
framework::make_ddim({ids_dim[0], table_dims[1]}));
|
|
|
|
|
} else if (lookup_table_version == "lookup_table_v2") {
|
|
|
|
|
outputs_dims.push_back(framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(ids_dim[0]), static_cast<int64_t>(ids_dim[1]),
|
|
|
|
|
static_cast<int64_t>(table_dims[1])}));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputsDim("Outputs", outputs_dims);
|
|
|
|
@ -93,10 +100,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto height_sections =
|
|
|
|
|
context.Attr<std::vector<int64_t>>("height_sections");
|
|
|
|
|
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
|
|
|
|
|
auto lookup_table_version =
|
|
|
|
|
context.Attr<std::string>("lookup_table_version");
|
|
|
|
|
|
|
|
|
|
operators::distributed::prefetchs(
|
|
|
|
|
id_names, out_names, embedding_name, false, lookup_tables, endpoints,
|
|
|
|
|
height_sections, context, context.scope());
|
|
|
|
|
|
|
|
|
|
if (lookup_table_version == "lookup_table_v2") {
|
|
|
|
|
auto& scope = context.scope();
|
|
|
|
|
auto emb_dim =
|
|
|
|
|
scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < id_names.size(); ++i) {
|
|
|
|
|
auto* id_var = scope.FindVar(id_names[i]);
|
|
|
|
|
auto* out_var = scope.FindVar(out_names[i]);
|
|
|
|
|
auto* id_tensor = id_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto* out_tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
auto id_dims = id_tensor->dims();
|
|
|
|
|
out_tensor->Resize(framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
|
|
|
|
|
static_cast<int64_t>(emb_dim)}));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -134,6 +161,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
|
|
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
|
|
|
|
|
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"lookup_table_version",
|
|
|
|
|
"(string, default lookup_table) "
|
|
|
|
|
"To distinguish between different versions of embedding OP")
|
|
|
|
|
.SetDefault(std::string("lookup_table"));
|
|
|
|
|
|
|
|
|
|
AddAttr<int64_t>("padding_idx",
|
|
|
|
|
"(int64, default -1) "
|
|
|
|
|
"If the value is -1, it makes no effect to lookup. "
|
|
|
|
|