|
|
|
@ -47,43 +47,34 @@ namespace distributed {
|
|
|
|
|
|
|
|
|
|
// Mode selector with two enumerators: `training` and `infer`.
// NOTE(review): how each mode alters table behavior is not visible in this
// chunk -- confirm at the use sites.
enum Mode { training, infer };
|
|
|
|
|
|
|
|
|
|
/**
 * Knock-in ("entry") predicate, selected by the threshold type T.
 *
 * Only a declaration is given for the primary template; behavior comes from
 * the explicit specializations.
 *
 * @param count      counter supplied by the caller
 * @param threshold  policy-specific threshold
 * @return true when the entry policy accepts
 */
template <typename T>
inline bool entry(const int count, const T threshold);

/**
 * std::string policy: accepts unconditionally. Both arguments are ignored, so
 * they are left unnamed to avoid unused-parameter diagnostics.
 */
template <>
inline bool entry<std::string>(const int /*count*/,
                               const std::string /*threshold*/) {
  return true;
}

/**
 * int policy: accepts once `count` has reached `threshold` (inclusive).
 */
template <>
inline bool entry<int>(const int count, const int threshold) {
  return count >= threshold;
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool entry<float>(const int count, const float threshold) {
|
|
|
|
|
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
|
|
|
|
|
return uniform.GetValue() >= threshold;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Payload and bookkeeping state for a single stored value.
 *
 * The payload is `length` floats, all zero on construction. The remaining
 * fields are plain public state that the owning block reads and updates.
 */
struct VALUE {
  /**
   * @param length number of floats to allocate in data_
   *
   * Initializers are listed in declaration order. The vector constructor
   * zero-fills the payload, so no separate memset is needed (and nothing is
   * touched when length is 0).
   */
  explicit VALUE(size_t length)
      : length_(length),
        data_(length, 0.0f),
        count_(0),
        unseen_days_(0),
        need_save_(false),
        is_entry_(false) {}

  size_t length_;            // number of floats held in data_
  std::vector<float> data_;  // payload, zero-initialized
  int count_;                // starts at 0; compared against count thresholds
  int unseen_days_;          // use to check knock-out
  bool need_save_;           // whether need to save
  bool is_entry_;            // whether knock-in
};
|
|
|
|
|
|
|
|
|
|
inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
|
|
|
|
|
return value->count_ >= threshold;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
|
|
|
|
|
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
|
|
|
|
|
return uniform.GetValue() >= threshold;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class ValueBlock {
|
|
|
|
|
public:
|
|
|
|
|
explicit ValueBlock(const std::vector<std::string> &value_names,
|
|
|
|
@ -102,21 +93,21 @@ class ValueBlock {
|
|
|
|
|
|
|
|
|
|
// for Entry
|
|
|
|
|
{
|
|
|
|
|
if (entry_attr == "none") {
|
|
|
|
|
has_entry_ = false;
|
|
|
|
|
entry_func_ =
|
|
|
|
|
std::bind(entry<std::string>, std::placeholders::_1, "none");
|
|
|
|
|
} else {
|
|
|
|
|
has_entry_ = true;
|
|
|
|
|
auto slices = string::split_string<std::string>(entry_attr, "&");
|
|
|
|
|
if (slices[0] == "count_filter") {
|
|
|
|
|
if (slices[0] == "none") {
|
|
|
|
|
entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
|
|
|
|
|
} else if (slices[0] == "count_filter") {
|
|
|
|
|
int threshold = std::stoi(slices[1]);
|
|
|
|
|
entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
|
|
|
|
|
entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
|
|
|
|
|
} else if (slices[0] == "probability") {
|
|
|
|
|
float threshold = std::stof(slices[1]);
|
|
|
|
|
entry_func_ =
|
|
|
|
|
std::bind(entry<float>, std::placeholders::_1, threshold);
|
|
|
|
|
}
|
|
|
|
|
std::bind(&probility_entry, std::placeholders::_1, threshold);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Not supported Entry Type : %s, Only support [count_filter, "
|
|
|
|
|
"probability]",
|
|
|
|
|
slices[0]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -147,58 +138,87 @@ class ValueBlock {
|
|
|
|
|
|
|
|
|
|
~ValueBlock() {}
|
|
|
|
|
|
|
|
|
|
float *Init(const uint64_t &id) {
|
|
|
|
|
auto value = std::make_shared<VALUE>(value_length_);
|
|
|
|
|
for (int x = 0; x < value_names_.size(); ++x) {
|
|
|
|
|
initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
|
|
|
|
|
value_dims_[x]);
|
|
|
|
|
}
|
|
|
|
|
values_[id] = value;
|
|
|
|
|
return value->data_.data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<float *> Get(const uint64_t &id,
|
|
|
|
|
const std::vector<std::string> &value_names) {
|
|
|
|
|
const std::vector<std::string> &value_names,
|
|
|
|
|
const std::vector<int> &value_dims) {
|
|
|
|
|
auto pts = std::vector<float *>();
|
|
|
|
|
pts.reserve(value_names.size());
|
|
|
|
|
auto &values = values_.at(id);
|
|
|
|
|
for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
value_dims[i], value_dims_[i],
|
|
|
|
|
platform::errors::InvalidArgument("value dims is not match"));
|
|
|
|
|
pts.push_back(values->data_.data() +
|
|
|
|
|
value_offsets_.at(value_idx_.at(value_names[i])));
|
|
|
|
|
}
|
|
|
|
|
return pts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float *Get(const uint64_t &id) {
|
|
|
|
|
auto pts = std::vector<std::vector<float> *>();
|
|
|
|
|
auto &values = values_.at(id);
|
|
|
|
|
// pull
|
|
|
|
|
float *Init(const uint64_t &id, const bool with_update = true) {
|
|
|
|
|
if (!Has(id)) {
|
|
|
|
|
values_[id] = std::make_shared<VALUE>(value_length_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return values->data_.data();
|
|
|
|
|
auto &value = values_.at(id);
|
|
|
|
|
|
|
|
|
|
if (with_update) {
|
|
|
|
|
AttrUpdate(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float *InitFromInitializer(const uint64_t &id) {
|
|
|
|
|
if (Has(id)) {
|
|
|
|
|
if (has_entry_) {
|
|
|
|
|
Update(id);
|
|
|
|
|
return value->data_.data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AttrUpdate(std::shared_ptr<VALUE> value) {
|
|
|
|
|
// update state
|
|
|
|
|
value->unseen_days_ = 0;
|
|
|
|
|
++value->count_;
|
|
|
|
|
|
|
|
|
|
if (!value->is_entry_) {
|
|
|
|
|
value->is_entry_ = entry_func_(value);
|
|
|
|
|
if (value->is_entry_) {
|
|
|
|
|
// initialize
|
|
|
|
|
for (int x = 0; x < value_names_.size(); ++x) {
|
|
|
|
|
initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
|
|
|
|
|
value_dims_[x]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return Get(id);
|
|
|
|
|
}
|
|
|
|
|
return Init(id);
|
|
|
|
|
|
|
|
|
|
value->need_save_ = true;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// dont jude if (has(id))
|
|
|
|
|
float *Get(const uint64_t &id) {
|
|
|
|
|
auto &value = values_.at(id);
|
|
|
|
|
return value->data_.data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// for load, to reset count, unseen_days
|
|
|
|
|
std::shared_ptr<VALUE> GetValue(const uint64_t &id) { return values_.at(id); }
|
|
|
|
|
|
|
|
|
|
bool GetEntry(const uint64_t &id) {
|
|
|
|
|
auto value = values_.at(id);
|
|
|
|
|
auto &value = values_.at(id);
|
|
|
|
|
return value->is_entry_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Update(const uint64_t id) {
|
|
|
|
|
auto value = values_.at(id);
|
|
|
|
|
value->unseen_days_ = 0;
|
|
|
|
|
auto count = ++value->count_;
|
|
|
|
|
void SetEntry(const uint64_t &id, const bool state) {
|
|
|
|
|
auto &value = values_.at(id);
|
|
|
|
|
value->is_entry_ = state;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!value->is_entry_) {
|
|
|
|
|
value->is_entry_ = entry_func_(count);
|
|
|
|
|
// Ages every stored value by one (unseen_days_ += 1) and erases those whose
// unseen_days_ has reached `threshold`; the rest are kept.
void Shrink(const int threshold) {
  auto cursor = values_.begin();
  while (cursor != values_.end()) {
    auto &entry = cursor->second;
    entry->unseen_days_ += 1;
    const bool expired = entry->unseen_days_ >= threshold;
    if (expired) {
      // erase() yields the iterator to continue from.
      cursor = values_.erase(cursor);
    } else {
      ++cursor;
    }
  }
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -221,8 +241,7 @@ class ValueBlock {
|
|
|
|
|
const std::vector<int> &value_offsets_;
|
|
|
|
|
const std::unordered_map<std::string, int> &value_idx_;
|
|
|
|
|
|
|
|
|
|
bool has_entry_ = false;
|
|
|
|
|
std::function<bool(uint64_t)> entry_func_;
|
|
|
|
|
std::function<bool(std::shared_ptr<VALUE>)> entry_func_;
|
|
|
|
|
std::vector<std::shared_ptr<Initializer>> initializers_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|