|
|
@ -288,10 +288,6 @@ void RecurrentGradientMachine::init(
|
|
|
|
parameterIds_.push_back(para->getID());
|
|
|
|
parameterIds_.push_back(para->getID());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (subModelConfig->evaluator_names_size() > 0) {
|
|
|
|
|
|
|
|
evaluator_.reset(frames_[0]->makeEvaluator());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
|
|
|
|
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
|
|
|
@ -562,9 +558,6 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
|
|
|
|
std::vector<Argument> outArgs;
|
|
|
|
std::vector<Argument> outArgs;
|
|
|
|
frames_[i]->forward(inArgs, &outArgs, passType);
|
|
|
|
frames_[i]->forward(inArgs, &outArgs, passType);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (evaluator_ && passType == PASS_TEST) {
|
|
|
|
|
|
|
|
this->eval(evaluator_.get());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reorganizeOutput(passType);
|
|
|
|
reorganizeOutput(passType);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -581,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
|
|
|
|
for (auto& memoryFrameLine : memoryFrameLines_) {
|
|
|
|
for (auto& memoryFrameLine : memoryFrameLines_) {
|
|
|
|
memoryFrameLine.bootLayer->backward(nullptr);
|
|
|
|
memoryFrameLine.bootLayer->backward(nullptr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// call printers here so the gradient can be printed
|
|
|
|
|
|
|
|
if (evaluator_) {
|
|
|
|
|
|
|
|
this->eval(evaluator_.get());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void RecurrentGradientMachine::forwardBackward(
|
|
|
|
void RecurrentGradientMachine::forwardBackward(
|
|
|
@ -599,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward(
|
|
|
|
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
|
|
|
|
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
|
|
|
|
// call printers frame by frame
|
|
|
|
// call printers frame by frame
|
|
|
|
for (int i = 0; i < maxSequenceLength_; ++i) {
|
|
|
|
for (int i = 0; i < maxSequenceLength_; ++i) {
|
|
|
|
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";
|
|
|
|
VLOG(2) << "Recurrent Layer Group eval frame " << i << " begin";
|
|
|
|
evaluator->eval(*(frames_[i].get()));
|
|
|
|
evaluator->eval(*(frames_[i].get()));
|
|
|
|
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " end";
|
|
|
|
VLOG(2) << "Recurrent Layer Group eval frame " << i << " end";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -1097,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
|
|
|
|
|
|
|
|
|
|
|
|
copyDataOutlinkFrame(machineCur);
|
|
|
|
copyDataOutlinkFrame(machineCur);
|
|
|
|
|
|
|
|
|
|
|
|
// call value printer
|
|
|
|
|
|
|
|
if (evaluator_) {
|
|
|
|
|
|
|
|
evaluator_->eval(*(frames_[machineCur].get()));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// check eos
|
|
|
|
// check eos
|
|
|
|
const IVectorPtr& eosVec =
|
|
|
|
const IVectorPtr& eosVec =
|
|
|
|
eosFrameLine_->layers[machineCur]->getOutput().ids;
|
|
|
|
eosFrameLine_->layers[machineCur]->getOutput().ids;
|
|
|
|