Support AucRunner in PaddleBox (#22884)

* Support AucRunner in PaddleBox
* update some code style
v1.8
hutuxian 5 years ago committed by GitHub
parent c417f991c1
commit e6b87b3193
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,44 +41,44 @@ namespace paddle {
namespace framework {
void RecordCandidateList::ReSize(size_t length) {
_mutex.lock();
_capacity = length;
CHECK(_capacity > 0); // NOLINT
_candidate_list.clear();
_candidate_list.resize(_capacity);
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
capacity_ = length;
CHECK(capacity_ > 0); // NOLINT
candidate_list_.clear();
candidate_list_.resize(capacity_);
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::ReInit() {
_mutex.lock();
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) {
_mutex.lock();
mutex_.lock();
size_t index = 0;
++_total_size;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) {
_candidate_list[_cur_size++] = record;
_full = (_cur_size == _capacity);
if (!full_) {
candidate_list_[cur_size_++] = record;
full_ = (cur_size_ == capacity_);
} else {
CHECK(_cur_size == _capacity);
index = fleet_ptr->LocalRandomEngine()() % _total_size;
if (index < _capacity) {
_candidate_list[index] = record;
CHECK(cur_size_ == capacity_);
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_[index] = record;
}
}
index = fleet_ptr->LocalRandomEngine()() % _cur_size;
*result = _candidate_list[index];
_mutex.unlock();
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
*result = candidate_list_[index];
mutex_.unlock();
}
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0
if (box_ptr->Mode() == 1) { // For AucRunner
return 1;
} else {
return box_ptr->Phase();
}
#else
LOG(WARNING) << "It should be complied with BOX_PS...";
return current_phase_;

@ -27,6 +27,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
std::unordered_multimap<uint16_t, FeatureKey> feas_;
size_t shadow_index_ = -1; // Optimization for Reservoir Sample
RecordCandidate() {}
RecordCandidate(const Record& rec,
const std::unordered_set<uint16_t>& slot_index_to_replace) {
for (const auto& fea : rec.uint64_feasigns_) {
if (slot_index_to_replace.find(fea.slot()) !=
slot_index_to_replace.end()) {
feas_.insert({fea.slot(), fea.sign()});
}
}
}
RecordCandidate& operator=(const Record& rec) {
feas.clear();
feas_.clear();
ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()});
feas_.insert({fea.slot(), fea.sign()});
}
return *this;
}
@ -499,22 +513,67 @@ struct RecordCandidate {
class RecordCandidateList {
public:
RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete;
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
RecordCandidateList(const RecordCandidateList&) {}
size_t Size() { return cur_size_; }
void ReSize(size_t length);
void ReInit();
void ReInitPass() {
for (size_t i = 0; i < cur_size_; ++i) {
if (candidate_list_[i].shadow_index_ != i) {
candidate_list_[i].ins_id_ =
candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
candidate_list_[i].feas_.swap(
candidate_list_[candidate_list_[i].shadow_index_].feas_);
candidate_list_[i].shadow_index_ = i;
}
}
candidate_list_.resize(cur_size_);
}
void AddAndGet(const Record& record, RecordCandidate* result);
void AddAndGet(const Record& record, size_t& index_result) { // NOLINT
// std::unique_lock<std::mutex> lock(mutex_);
size_t index = 0;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!full_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_.back().shadow_index_ = cur_size_;
++cur_size_;
full_ = (cur_size_ == capacity_);
} else {
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
}
}
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
index_result = candidate_list_[index].shadow_index_;
}
const RecordCandidate& Get(size_t index) const {
PADDLE_ENFORCE_LT(
index, candidate_list_.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in candidate_list[%lu].",
index, candidate_list_.size()));
return candidate_list_[index];
}
void SetSlotIndexToReplace(
const std::unordered_set<uint16_t>& slot_index_to_replace) {
slot_index_to_replace_ = slot_index_to_replace;
}
private:
size_t _capacity = 0;
std::mutex _mutex;
bool _full = false;
size_t _cur_size = 0;
size_t _total_size = 0;
std::vector<RecordCandidate> _candidate_list;
size_t capacity_ = 0;
std::mutex mutex_;
bool full_ = false;
size_t cur_size_ = 0;
size_t total_size_ = 0;
std::vector<RecordCandidate> candidate_list_;
std::unordered_set<uint16_t> slot_index_to_replace_;
};
template <class AR>

@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
void MultiSlotDataset::GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) {
const auto& slots_shuffle_original_data = GetSlotsOriginalData();
for (const auto& rec : slots_shuffle_original_data) {
RecordCandidate rand_rec;
Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot);
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
<< " repush feasign num: " << debug_push_cnt;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
void MultiSlotDataset::PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slots) { // NOLINT
int out_channel_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size()
<< " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return;
}
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
}
CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle";
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
PADDLE_ENFORCE_EQ(slots_shuffle_fea_eval_, true,
platform::errors::PreconditionNotMet(
"fea eval mode off, need to set on for slots shuffle"));
platform::Timer timeline;
timeline.Start();
std::unordered_set<uint16_t> index_slots;
PreprocessChannel(slots_to_replace, index_slots);
std::vector<Record> random_data;
random_data.clear();
// get slots shuffled random_data

@ -67,6 +67,7 @@ class Dataset {
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
@ -108,10 +109,7 @@ class Dataset {
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle(int thread_num = -1) = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual void SetInputChannel(const Channel<T>& input_channel) {
input_channel_ = input_channel;
}
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
return data_feed_desc_;
}
virtual int GetChannelNum() { return channel_num_; }
virtual bool EnablePvMerge() { return enable_pv_merge_; }
virtual std::vector<paddle::framework::DataFeed*> GetReaders();
virtual void CreateChannel();
virtual void RegisterClientToClientMsgHandler();
@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
virtual void LocalShuffle();
virtual void GlobalShuffle(int thread_num = -1);
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {}
virtual const std::vector<T>& GetSlotsOriginalData() {
return slots_shuffle_original_data_;
}
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> {
}
std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
}
virtual void PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slot); // NOLINT
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual void GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
};

@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
<< " s";
VLOG(3) << "End PushSparseGrad";
}
void BoxWrapper::GetRandomReplace(const std::vector<Record>& pass_data) {
VLOG(0) << "Begin GetRandomReplace";
size_t ins_num = pass_data.size();
replace_idx_.resize(ins_num);
for (auto& cand_list : random_ins_pool_list) {
cand_list.ReInitPass();
}
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, ins_num]() {
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomReplace begin for thread[" << tid
<< "], and process [" << start << ", " << end
<< "), total ins: " << ins_num;
auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
random_pool.AddAndGet(ins, replace_idx_[i]);
}
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
pass_done_semi_->Put(1);
VLOG(0) << "End GetRandomReplace";
}
void BoxWrapper::GetRandomData(
const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
VLOG(0) << "Begin GetRandomData";
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, &slots_to_replace,
result]() {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
size_t ins_num = pass_data.size();
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomData begin for thread[" << tid << "], and process ["
<< start << ", " << end << "), total ins: " << ins_num;
const auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
const RecordCandidate& rand_rec = random_pool.Get(replace_idx_[i]);
Record new_rec = ins;
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
(*result)[i] = std::move(new_rec);
}
VLOG(3) << "thread[" << tid << "]: erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
VLOG(0) << "End GetRandomData";
}
void BoxWrapper::AddReplaceFeasign(boxps::PSAgentBase* p_agent,
int feed_pass_thread_num) {
VLOG(0) << "Enter AddReplaceFeasign Function";
int semi;
pass_done_semi_->Get(semi);
VLOG(0) << "Last Pass had updated random pool done. Begin AddReplaceFeasign";
std::vector<std::thread> threads;
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads.push_back(std::thread([this, tid, p_agent, feed_pass_thread_num]() {
VLOG(3) << "AddReplaceFeasign begin for thread[" << tid << "]";
for (size_t pool_id = tid; pool_id < random_ins_pool_list.size();
pool_id += feed_pass_thread_num) {
auto& random_pool = random_ins_pool_list[pool_id];
for (size_t i = 0; i < random_pool.Size(); ++i) {
auto& ins_candidate = random_pool.Get(i);
for (const auto& pair : ins_candidate.feas_) {
p_agent->AddKey(pair.second.uint64_feasign_, tid);
}
}
}
}));
}
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads[tid].join();
}
VLOG(0) << "End AddReplaceFeasign";
}
} // end namespace framework
} // end namespace paddle
#endif

File diff suppressed because it is too large Load Diff

@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);
@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);

@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::BoxHelper::SlotsShuffle,
py::call_guard<py::gil_scoped_release>());
} // end BoxHelper
@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
.def("initialize_gpu_and_load_model",
&framework::BoxWrapper::InitializeGPUAndLoadModel,
py::call_guard<py::gil_scoped_release>())
.def("initialize_auc_runner", &framework::BoxWrapper::InitializeAucRunner,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
py::call_guard<py::gil_scoped_release>())
.def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag,
.def("flip_phase", &framework::BoxWrapper::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
py::call_guard<py::gil_scoped_release>())

@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_fleet_send_sleep_seconds",
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")

@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
def _dynamic_adjust_after_train(self):
pass
def slots_shuffle(self, slots):
"""
Slots Shuffle
Slots Shuffle is a shuffle method in slots level, which is usually used
in sparse feature with large scale of instances. To compare the metric, i.e.
auc while doing slots shuffle on one or several slots with baseline to
evaluate the importance level of slots(features).
Args:
slots(list[string]): the set of slots(string) to do slots shuffle.
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
#suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
slots_set = set(slots)
self.boxps.slots_shuffle(slots_set)

@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
exe.run(fluid.default_startup_program())
datasets[0].load_into_memory()
datasets[0].begin_pass()
datasets[0].slots_shuffle([])
datasets[1].preload_into_memory()
exe.train_from_dataset(
program=fluid.default_main_program(),

@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
dataset.enable_pv_merge()
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(10000, True)
dataset.set_fea_eval(1, True)
dataset.slots_shuffle(["slot1"])
dataset.local_shuffle()
dataset.set_generate_unique_feasigns(True, 15)

Loading…
Cancel
Save