|
|
|
@ -41,6 +41,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
mpi_rank_ = trainer_desc.mpi_rank() / 2;
|
|
|
|
|
mpi_size_ = trainer_desc.mpi_size() / 2;
|
|
|
|
|
dump_file_num_ = trainer_desc.dump_file_num();
|
|
|
|
|
const std::vector<paddle::framework::DataFeed *> readers =
|
|
|
|
|
dataset->GetReaders();
|
|
|
|
|
|
|
|
|
@ -68,20 +70,25 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
|
|
|
|
|
SetDebug(trainer_desc.debug());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DistMultiTrainer::DumpWork() {
|
|
|
|
|
void DistMultiTrainer::DumpWork(int tid) {
|
|
|
|
|
#ifdef _LINUX
|
|
|
|
|
int err_no = 0;
|
|
|
|
|
std::string path = string::format_string(
|
|
|
|
|
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
|
|
|
|
|
while (1) {
|
|
|
|
|
std::string out_str;
|
|
|
|
|
if (!queue_->Get(out_str)) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
size_t write_count =
|
|
|
|
|
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp_.get());
|
|
|
|
|
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
|
|
|
|
|
if (write_count != out_str.length()) {
|
|
|
|
|
VLOG(3) << "dump text failed";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
write_count = fwrite_unlocked("\n", 1, 1, fp_.get());
|
|
|
|
|
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
|
|
|
|
|
if (write_count != 1) {
|
|
|
|
|
VLOG(3) << "dump text failed";
|
|
|
|
|
continue;
|
|
|
|
@ -92,20 +99,27 @@ void DistMultiTrainer::DumpWork() {
|
|
|
|
|
|
|
|
|
|
// Sets up the dump pipeline: creates the string channel the workers write
// into, attaches that channel to every worker, and starts the thread(s) that
// run DumpWork.
//
// NOTE(review): this block uses dump_thread_ both as a single std::thread
// (direct assignment) and as a container of threads (push_back). It looks like
// the pre-change and post-change lines of a diff are both present here --
// confirm against the member declaration which variant is meant to remain.
void DistMultiTrainer::InitDumpEnv() {
|
|
|
|
|
// Channel carrying the text records that the dump thread(s) consume.
queue_ = paddle::framework::MakeChannel<std::string>();
|
|
|
|
|
// Opens one output file for this rank, named "<dump_fields_path_>/part-<rank>",
// and keeps the handle in fp_.
// NOTE(review): err_no is passed to fs_open_write but never inspected here.
int err_no = 0;
|
|
|
|
|
std::string path = string::format_string(
|
|
|
|
|
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
|
|
|
|
|
|
|
|
|
|
fp_ = fs_open_write(path, &err_no, dump_converter_);
|
|
|
|
|
// Every worker gets the same channel as its writer target.
for (int i = 0; i < thread_num_; ++i) {
|
|
|
|
|
workers_[i]->SetChannelWriter(queue_.get());
|
|
|
|
|
}
|
|
|
|
|
// Single dump thread running the zero-argument DumpWork.
// NOTE(review): presumably the pre-change variant -- verify.
dump_thread_ = std::thread(&DistMultiTrainer::DumpWork, this);
|
|
|
|
|
// Number of dump threads for this rank: 1 by default. When dump_file_num_
// exceeds mpi_size_, it becomes dump_file_num_ / mpi_size_, plus one more if
// the remainder dump_file_num_ % mpi_size_ is greater than mpi_rank_.
dump_thread_num_ = 1;
|
|
|
|
|
if (dump_file_num_ > mpi_size_) {
|
|
|
|
|
dump_thread_num_ = dump_file_num_ / mpi_size_;
|
|
|
|
|
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
|
|
|
|
|
dump_thread_num_ += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Starts dump_thread_num_ threads; thread i runs DumpWork(i), so the loop
// index is the tid argument DumpWork receives.
for (int i = 0; i < dump_thread_num_; i++) {
|
|
|
|
|
dump_thread_.push_back(
|
|
|
|
|
std::thread(std::bind(&DistMultiTrainer::DumpWork, this, i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Shuts the dump pipeline down: closes the channel, waits for the dump
// thread(s) to finish, then releases the channel.
//
// NOTE(review): dump_thread_ is joined directly and also iterated as a
// container of threads. It looks like both the pre-change and post-change
// lines of a diff are present -- confirm which variant is meant to remain.
void DistMultiTrainer::FinalizeDumpEnv() {
|
|
|
|
|
// Closed before joining; presumably this makes queue_->Get() return false so
// the DumpWork loops exit -- confirm against the channel implementation.
queue_->Close();
|
|
|
|
|
// Joins dump_thread_ as a single thread.
dump_thread_.join();
|
|
|
|
|
// Joins each element of dump_thread_ in turn.
for (auto &th : dump_thread_) {
|
|
|
|
|
th.join();
|
|
|
|
|
}
|
|
|
|
|
// Drops this object's reference to the channel.
queue_.reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|