fix MultiGradientMachine train and infer

gangliao-patch-1
qiaolongfei 8 years ago
parent 7573e9eba4
commit 55684af208

@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
}
}
MultiGradientMachine::~MultiGradientMachine() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec;
@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() {
}
}
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator();
}

@ -176,6 +176,8 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual ~MultiGradientMachine();
virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs,
@ -193,8 +195,6 @@ public:
virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const;

Loading…
Cancel
Save