|
|
|
@ -22,6 +22,12 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
void HogwildWorker::Initialize(const TrainerDesc& desc) {
|
|
|
|
|
fetch_config_ = desc.fetch_config();
|
|
|
|
|
param_ = desc.hogwild_param();
|
|
|
|
|
skip_ops_.resize(param_.skip_ops_size());
|
|
|
|
|
LOG(WARNING) << "skip op size: " << skip_ops_.size();
|
|
|
|
|
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
|
|
|
|
|
skip_ops_[i] = param_.skip_ops(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
|
|
|
|
@ -92,9 +98,18 @@ void HogwildWorker::TrainFilesWithProfiler() {
|
|
|
|
|
read_time += timeline.ElapsedSec();
|
|
|
|
|
total_time += timeline.ElapsedSec();
|
|
|
|
|
for (size_t i = 0; i < ops_.size(); ++i) {
|
|
|
|
|
bool need_skip = false;
|
|
|
|
|
for (auto t = 0u; t < skip_ops_.size(); ++t) {
|
|
|
|
|
if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
|
|
|
|
|
need_skip = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
timeline.Start();
|
|
|
|
|
VLOG(3) << "Going to run op " << op_name[i];
|
|
|
|
|
ops_[i]->Run(*thread_scope_, place_);
|
|
|
|
|
if (!need_skip) {
|
|
|
|
|
ops_[i]->Run(*thread_scope_, place_);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "Op " << op_name[i] << " Finished";
|
|
|
|
|
timeline.Pause();
|
|
|
|
|
op_total_time[i] += timeline.ElapsedSec();
|
|
|
|
@ -127,7 +142,16 @@ void HogwildWorker::TrainFiles() {
|
|
|
|
|
int cur_batch;
|
|
|
|
|
while ((cur_batch = device_reader_->Next()) > 0) {
|
|
|
|
|
for (auto& op : ops_) {
|
|
|
|
|
op->Run(*thread_scope_, place_);
|
|
|
|
|
bool need_skip = false;
|
|
|
|
|
for (auto t = 0u; t < skip_ops_.size(); ++t) {
|
|
|
|
|
if (op->Type().find(skip_ops_[t]) != std::string::npos) {
|
|
|
|
|
need_skip = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!need_skip) {
|
|
|
|
|
op->Run(*thread_scope_, place_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrintFetchVars();
|
|
|
|
|