Fix MultiGradientMachine error

avx_docs
hedaoyuan 8 years ago
parent ca62c104ec
commit b1c22b6790

@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
void MultiGradientMachine::eval(Evaluator* evaluator) const {
for (auto& thread : threads_) {
SetDevice device(thread->getDeviceId());
thread->getGradientMachine()->eval(evaluator);
if (thread->hasInputData()) {
thread->getGradientMachine()->eval(evaluator);
}
}
}
@ -356,14 +358,20 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
REGISTER_TIMER("waitOutArgs");
thread->waitOutArgsReady();
}
outArgs_.resize(threads_[0]->getOutArgs().size());
// outArgs_.size() only need to be calculated once.
static int size = threads_[threads_.size() - 1]->getOutArgs().size();
outArgs_.resize(size);
REGISTER_TIMER("copyOutArgs");
for (size_t i = 0; i < outArgs_.size(); ++i) {
std::vector<Argument> args;
args.reserve(threads_.size());
for (auto& thread : threads_) {
args.push_back(thread->getOutArgs()[i]);
// If the thread input is empty, then the output is empty.
auto tmp = thread->getOutArgs();
if (tmp.size() > 0) {
args.push_back(tmp[i]);
}
}
outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
}
@ -534,7 +542,7 @@ void TrainerThread::prefetch() {
void TrainerThread::forward() {
if (!inArgsCopied_) {
REGISTER_TIMER("copyInArgs");
copyInArgs();
batchSize_ = copyInArgs();
} else {
inArgsCopied_ = false;
}
@ -564,7 +572,12 @@ void TrainerThread::forward() {
{
REGISTER_TIMER("thread_forward");
gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
if (batchSize_ > 0) {
gradientMachine_->forward(
inArgs_, &outArgs_, multiMachine_->getPassType());
} else {
outArgs_.clear();
}
}
outArgsReadySem_.post();
}
@ -574,7 +587,13 @@ void TrainerThread::backward() {
if (multiMachine_->isPassGrad()) {
copyOutputGrad();
}
gradientMachine_->backward(backwardCallback_);
if (batchSize_ > 0) {
gradientMachine_->backward(backwardCallback_);
} else {
for (size_t i = parameters_.size(); i > 0; i--) {
backwardCallback(parameters_[i - 1].get());
}
}
if (multiMachine_->hasNonstaticCpuParamters()) {
mergeCpuGradients();
}
@ -732,7 +751,7 @@ void TrainerThread::notifyValueReady(int paramId) {
notifyValueDispatch(paramId);
}
void TrainerThread::copyInArgs() {
int TrainerThread::copyInArgs() {
const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
int numThreads = multiMachine_->getAllThreads().size();
int32_t numSequences = fullInArgs[0].getNumSequences();
@ -748,7 +767,7 @@ void TrainerThread::copyInArgs() {
}
if (copySize == 0) {
return;
return 0;
}
for (size_t i = 0; i < fullInArgs.size(); i++) {
@ -758,6 +777,7 @@ void TrainerThread::copyInArgs() {
copySize,
FLAGS_parallel_nn ? false : multiMachine_->useGpu());
}
return copySize;
}
void TrainerThread::mergeCpuGradients() {

@ -387,6 +387,9 @@ public:
/// copy the output gradient from the main GradientMachine.
void copyOutputGrad();
/// Whether the thread has input data.
bool hasInputData() { return batchSize_ != 0; }
protected:
void mergeCpuGradients();
@ -407,7 +410,7 @@ protected:
void copyGradToBufferThread();
void gradCollectThread();
void copyInArgs();
int copyInArgs();
void forward();
void backward();
void backwardCallback(Parameter* para);
@ -467,6 +470,8 @@ protected:
/// indicate whether inArgs is copied before forward()
bool inArgsCopied_;
int batchSize_;
};
} // namespace paddle

Loading…
Cancel
Save