|
|
|
@ -544,6 +544,12 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
|
|
|
|
|
const std::vector<Argument> inArgs;
|
|
|
|
|
std::vector<Argument> outArgs;
|
|
|
|
|
frames_[i]->forward(inArgs, &outArgs, passType);
|
|
|
|
|
if (hasSubseq) {
|
|
|
|
|
for (auto& outFrameLine : outFrameLines_) {
|
|
|
|
|
CHECK(outFrameLine.frames[i]->getOutput().sequenceStartPositions)
|
|
|
|
|
<< "In hierachical RNN, all out links should be from sequences.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (evaluator_ && passType == PASS_TEST) {
|
|
|
|
|
this->eval(evaluator_.get());
|
|
|
|
@ -635,16 +641,15 @@ void RecurrentGradientMachine::createInFrameInfo(int inlinkId,
|
|
|
|
|
std::vector<int> sequenceStartPositions;
|
|
|
|
|
const int* subSequenceStartPositions = nullptr;
|
|
|
|
|
|
|
|
|
|
if (hasSubseq) { // for sequenceScatterAgentLayer
|
|
|
|
|
subSequenceStartPositions =
|
|
|
|
|
input.subSequenceStartPositions->getData(false);
|
|
|
|
|
if (hasSubseq) { // for sequenceScatterAgentLayer
|
|
|
|
|
subSequenceStartPositions = input.subSequenceStartPositions->getData(false);
|
|
|
|
|
inlinkInfo->seqStartPosIndex.clear();
|
|
|
|
|
inlinkInfo->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
|
|
|
|
|
}
|
|
|
|
|
// maxSequenceLength_: max topLevelLength in allsamples
|
|
|
|
|
for (int i = 0; i < maxSequenceLength_; ++i) {
|
|
|
|
|
if (hasSubseq) {
|
|
|
|
|
sequenceStartPositions.push_back(0); // first element = 0
|
|
|
|
|
sequenceStartPositions.push_back(0); // first element = 0
|
|
|
|
|
}
|
|
|
|
|
int numSeqs = 0;
|
|
|
|
|
for (size_t j = 0; j < numSequences; ++j) {
|
|
|
|
@ -676,9 +681,9 @@ void RecurrentGradientMachine::createInFrameInfo(int inlinkId,
|
|
|
|
|
}
|
|
|
|
|
if (hasSubseq) {
|
|
|
|
|
// inFrameLine create sequenceStartPositions one time
|
|
|
|
|
CHECK_EQ(sequenceStartPositions.size(),
|
|
|
|
|
static_cast<size_t>(maxSequenceLength_ +
|
|
|
|
|
input.getNumSubSequences()));
|
|
|
|
|
CHECK_EQ(
|
|
|
|
|
sequenceStartPositions.size(),
|
|
|
|
|
static_cast<size_t>(maxSequenceLength_ + input.getNumSubSequences()));
|
|
|
|
|
CHECK_EQ(inlinkInfo->seqStartPosIndex.size(),
|
|
|
|
|
static_cast<size_t>(maxSequenceLength_ + 1));
|
|
|
|
|
createSeqPos(sequenceStartPositions, &inlinkInfo->sequenceStartPositions);
|
|
|
|
@ -1102,10 +1107,12 @@ size_t RecurrentGradientMachine::beamShrink(std::vector<Path>& newPaths,
|
|
|
|
|
newPaths.end(), Path::greaterPath);
|
|
|
|
|
newPaths.resize(totalExpandCount + minNewPathSize);
|
|
|
|
|
|
|
|
|
|
real minPathLogProb = std::min_element(newPaths.end() - minNewPathSize,
|
|
|
|
|
newPaths.end())->logProb;
|
|
|
|
|
real maxPathLogProb = std::max_element(newPaths.end() - minNewPathSize,
|
|
|
|
|
newPaths.end())->logProb;
|
|
|
|
|
real minPathLogProb =
|
|
|
|
|
std::min_element(newPaths.end() - minNewPathSize, newPaths.end())
|
|
|
|
|
->logProb;
|
|
|
|
|
real maxPathLogProb =
|
|
|
|
|
std::max_element(newPaths.end() - minNewPathSize, newPaths.end())
|
|
|
|
|
->logProb;
|
|
|
|
|
|
|
|
|
|
// Remove the already formed paths that are relatively short
|
|
|
|
|
finalPaths_[seqId].erase(
|
|
|
|
|