|
|
|
@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
real Trainer::checkGradient() {
|
|
|
|
|
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
|
|
|
|
|
trainerInternal_.getGradientMachine()->start();
|
|
|
|
|
std::vector<ParameterPtr>& parameters =
|
|
|
|
|
trainerInternal_.getGradientMachine()->getNonStaticParameters();
|
|
|
|
|
DataBatch dataBatch;
|
|
|
|
@ -390,7 +390,7 @@ void Trainer::startTrain() {
|
|
|
|
|
dataProvider_->reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
|
|
|
|
|
trainerInternal_.getGradientMachine()->start();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }
|
|
|
|
|