|
|
|
@ -287,10 +287,6 @@ void RecurrentGradientMachine::init(
|
|
|
|
|
parameterIds_.push_back(para->getID());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (subModelConfig->evaluator_names_size() > 0) {
|
|
|
|
|
evaluator_.reset(frames_[0]->makeEvaluator());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
|
|
|
|
@ -561,9 +557,6 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
|
|
|
|
|
std::vector<Argument> outArgs;
|
|
|
|
|
frames_[i]->forward(inArgs, &outArgs, passType);
|
|
|
|
|
}
|
|
|
|
|
if (evaluator_ && passType == PASS_TEST) {
|
|
|
|
|
this->eval(evaluator_.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
reorganizeOutput(passType);
|
|
|
|
|
}
|
|
|
|
@ -577,11 +570,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
|
|
|
|
|
for (auto& memoryFrameLine : memoryFrameLines_) {
|
|
|
|
|
memoryFrameLine.bootLayer->backward(nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// call printers here so the gradient can be printed
|
|
|
|
|
if (evaluator_) {
|
|
|
|
|
this->eval(evaluator_.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecurrentGradientMachine::forwardBackward(
|
|
|
|
@ -595,9 +583,9 @@ void RecurrentGradientMachine::forwardBackward(
|
|
|
|
|
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
|
|
|
|
|
// call printers frame by frame
|
|
|
|
|
for (int i = 0; i < maxSequenceLength_; ++i) {
|
|
|
|
|
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";
|
|
|
|
|
VLOG(2) << "Recurrent Layer Group eval frame " << i << " begin";
|
|
|
|
|
evaluator->eval(*(frames_[i].get()));
|
|
|
|
|
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " end";
|
|
|
|
|
VLOG(2) << "Recurrent Layer Group eval frame " << i << " end";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1093,10 +1081,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
|
|
|
|
|
|
|
|
|
|
copyDataOutlinkFrame(machineCur);
|
|
|
|
|
|
|
|
|
|
// call value printer
|
|
|
|
|
if (evaluator_) {
|
|
|
|
|
evaluator_->eval(*(frames_[machineCur].get()));
|
|
|
|
|
}
|
|
|
|
|
// check eos
|
|
|
|
|
const IVectorPtr& eosVec =
|
|
|
|
|
eosFrameLine_->layers[machineCur]->getOutput().ids;
|
|
|
|
|