fix MultiGradientMachine train and infer

gangliao-patch-1
qiaolongfei 8 years ago
parent 7573e9eba4
commit 55684af208

@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
} }
} }
MultiGradientMachine::~MultiGradientMachine() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*> std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() { MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec; std::vector<const std::vector<ParameterPtr>*> vec;
@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() {
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const { Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator(); return threads_[0]->getGradientMachine()->makeEvaluator();
} }

@ -176,6 +176,8 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual ~MultiGradientMachine();
virtual void prefetch(const std::vector<Argument>& inArgs); virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs, virtual void forward(const std::vector<Argument>& inArgs,
@ -193,8 +195,6 @@ public:
virtual void onPassEnd(); virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const; virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const; virtual void eval(Evaluator* evaluator) const;

Loading…
Cancel
Save