|
|
|
@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MultiGradientMachine::~MultiGradientMachine() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->stop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<const std::vector<ParameterPtr>*>
|
|
|
|
|
MultiGradientMachine::getSlaveParameters() {
|
|
|
|
|
std::vector<const std::vector<ParameterPtr>*> vec;
|
|
|
|
@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiGradientMachine::finish() {
|
|
|
|
|
for (auto& thread : threads_) {
|
|
|
|
|
thread->stop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Evaluator* MultiGradientMachine::makeEvaluator() const {
|
|
|
|
|
return threads_[0]->getGradientMachine()->makeEvaluator();
|
|
|
|
|
}
|
|
|
|
|