|
|
@ -38,7 +38,7 @@ public:
|
|
|
|
virtual void startPass() {}
|
|
|
|
virtual void startPass() {}
|
|
|
|
|
|
|
|
|
|
|
|
// called by Trainer then finishing a pass, ruturn true if pass accepted
|
|
|
|
// called by Trainer then finishing a pass, ruturn true if pass accepted
|
|
|
|
virtual bool finishPass(real cost = 0) { return true; }
|
|
|
|
virtual bool finishPass() { return true; }
|
|
|
|
|
|
|
|
|
|
|
|
// called by Trainer before backward() of a batch
|
|
|
|
// called by Trainer before backward() of a batch
|
|
|
|
// Return the type of pass it needs. This pass type will be passed
|
|
|
|
// Return the type of pass it needs. This pass type will be passed
|
|
|
@ -112,9 +112,9 @@ public:
|
|
|
|
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
|
|
|
|
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
virtual bool finishPass(real cost = 0) {
|
|
|
|
virtual bool finishPass() {
|
|
|
|
syncThreadPool_->execPlusOwner(
|
|
|
|
syncThreadPool_->execPlusOwner(
|
|
|
|
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); });
|
|
|
|
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|