|
|
|
@ -49,6 +49,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
|
|
|
|
|
unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
|
|
|
|
|
float min = Attr<float>("min");
|
|
|
|
|
float max = Attr<float>("max");
|
|
|
|
|
bool auto_grown_table = Attr<bool>("auto_grown_table");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
|
|
|
|
|
"The type of Out var should be LodTensor.");
|
|
|
|
@ -71,8 +72,11 @@ class LookupSparseTableOp : public framework::OperatorBase {
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
|
|
|
|
|
framework::proto::VarType::FP32,
|
|
|
|
|
"The sparse table only support FP32");
|
|
|
|
|
|
|
|
|
|
auto non_keys_pair = w_t->Get(keys, out_t);
|
|
|
|
|
if (!auto_grown_table) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
|
|
|
|
|
"there is some keys does exists in the sparse table.");
|
|
|
|
|
}
|
|
|
|
|
auto value_shape = w_t->value().dims();
|
|
|
|
|
value_shape[0] = 1;
|
|
|
|
|
for (const auto &it : non_keys_pair) {
|
|
|
|
@ -130,6 +134,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"Note that if seed is not 0, this operator will always "
|
|
|
|
|
"generate the same random numbers every time.")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddAttr<bool>("auto_grown_table",
|
|
|
|
|
"(bool default false)"
|
|
|
|
|
"Whether create new value if for nonexistent key.")
|
|
|
|
|
.SetDefault(true);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Lookup Sprase Tablel Operator.
|
|
|
|
|
|
|
|
|
|