|
|
|
@ -166,12 +166,16 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
|
|
|
|
|
|
|
|
|
|
outArgStream_ = HPPL_STREAM_1;
|
|
|
|
|
|
|
|
|
|
start();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiGradientMachine::start() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->start();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MultiGradientMachine::~MultiGradientMachine() {
|
|
|
|
|
void MultiGradientMachine::finish() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->stop();
|
|
|
|
|
}
|
|
|
|
@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
|
|
|
|
|
|
|
|
|
|
gradStream_ = HPPL_STREAM_2;
|
|
|
|
|
valueStream_ = HPPL_STREAM_3;
|
|
|
|
|
stopping_ = false;
|
|
|
|
|
stopping_ = true;
|
|
|
|
|
updateCounter_ = 0;
|
|
|
|
|
parameterUpdated_ = false;
|
|
|
|
|
}
|
|
|
|
@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config,
|
|
|
|
|
TrainerThread::~TrainerThread() { stop(); }
|
|
|
|
|
|
|
|
|
|
void TrainerThread::start() {
|
|
|
|
|
if (!stopping_) return;
|
|
|
|
|
|
|
|
|
|
stopping_ = false;
|
|
|
|
|
|
|
|
|
|
gradientMachine_->start();
|
|
|
|
|
|
|
|
|
|
computeThread_.reset(new std::thread([this]() { computeThread(); }));
|
|
|
|
|