|
|
|
@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
|
|
|
|
|
|
|
|
|
|
outArgStream_ = HPPL_STREAM_1;
|
|
|
|
|
|
|
|
|
|
start();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiGradientMachine::start() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->start();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiGradientMachine::finish() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->stop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<const std::vector<ParameterPtr>*>
|
|
|
|
|
MultiGradientMachine::getSlaveParameters() {
|
|
|
|
|
std::vector<const std::vector<ParameterPtr>*> vec;
|
|
|
|
@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiGradientMachine::finish() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->stop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Evaluator* MultiGradientMachine::makeEvaluator() const {
|
|
|
|
|
return threads_[0]->getGradientMachine()->makeEvaluator();
|
|
|
|
|
}
|
|
|
|
@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
|
|
|
|
|
|
|
|
|
|
gradStream_ = HPPL_STREAM_2;
|
|
|
|
|
valueStream_ = HPPL_STREAM_3;
|
|
|
|
|
stopping_ = false;
|
|
|
|
|
stopping_ = true;
|
|
|
|
|
updateCounter_ = 0;
|
|
|
|
|
parameterUpdated_ = false;
|
|
|
|
|
}
|
|
|
|
@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config,
|
|
|
|
|
TrainerThread::~TrainerThread() { stop(); }
|
|
|
|
|
|
|
|
|
|
void TrainerThread::start() {
|
|
|
|
|
if (!stopping_) return;
|
|
|
|
|
|
|
|
|
|
stopping_ = false;
|
|
|
|
|
|
|
|
|
|
gradientMachine_->start();
|
|
|
|
|
|
|
|
|
|
computeThread_.reset(new std::thread([this]() { computeThread(); }));
|
|
|
|
|