|
|
|
@ -358,6 +358,7 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
|
|
|
|
|
REGISTER_TIMER("waitOutArgs");
|
|
|
|
|
thread->waitOutArgsReady();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// outArgs_.size() only need to be calculated once.
|
|
|
|
|
static int size = threads_[threads_.size() - 1]->getOutArgs().size();
|
|
|
|
|
outArgs_.resize(size);
|
|
|
|
@ -574,9 +575,9 @@ void TrainerThread::forward() {
|
|
|
|
|
REGISTER_TIMER("thread_forward");
|
|
|
|
|
if (batchSize_ > 0) {
|
|
|
|
|
gradientMachine_->forward(
|
|
|
|
|
inArgs_, &outArgs_, multiMachine_->getPassType());
|
|
|
|
|
inArgs_, &outArgs_, multiMachine_->getPassType());
|
|
|
|
|
} else {
|
|
|
|
|
outArgs_.clear();
|
|
|
|
|
outArgs_.clear();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
outArgsReadySem_.post();
|
|
|
|
|