|
|
|
@ -21,25 +21,34 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
|
|
|
|
|
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
|
|
|
|
|
Dataset* dataset) {
|
|
|
|
|
thread_num_ = trainer_desc.thread_num();
|
|
|
|
|
// get filelist from trainer_desc here
|
|
|
|
|
workers_.resize(thread_num_);
|
|
|
|
|
readers_.resize(thread_num_);
|
|
|
|
|
|
|
|
|
|
if (NULL == dataset) {
|
|
|
|
|
readers_.resize(thread_num_);
|
|
|
|
|
for (int i = 0; i < thread_num_; ++i) {
|
|
|
|
|
readers_[i] =
|
|
|
|
|
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
|
|
|
|
|
readers_[i]->Init(trainer_desc.data_desc());
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> filelist_vec;
|
|
|
|
|
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
|
|
|
|
|
filelist_vec.push_back(trainer_desc.filelist(i));
|
|
|
|
|
}
|
|
|
|
|
readers_[0]->SetFileList(filelist_vec);
|
|
|
|
|
} else {
|
|
|
|
|
// readers_ = dataset.get_readers(); ?
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < thread_num_; ++i) {
|
|
|
|
|
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
|
|
|
|
|
trainer_desc.device_worker_name());
|
|
|
|
|
readers_[i] =
|
|
|
|
|
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
|
|
|
|
|
workers_[i]->SetDeviceIndex(i);
|
|
|
|
|
readers_[i]->Init(trainer_desc.data_desc());
|
|
|
|
|
workers_[i]->SetDataFeed(readers_[i]);
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> filelist_vec;
|
|
|
|
|
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
|
|
|
|
|
filelist_vec.push_back(trainer_desc.filelist(i));
|
|
|
|
|
}
|
|
|
|
|
readers_[0]->SetFileList(filelist_vec);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// call only after all resources are set in current trainer
|
|
|
|
|