|
|
|
@ -73,9 +73,10 @@ void SequenceSliceLayer::checkInputs() {
|
|
|
|
CHECK(inputSeq.hasSeq()) << "The first input of sequence slice layer "
|
|
|
|
CHECK(inputSeq.hasSeq()) << "The first input of sequence slice layer "
|
|
|
|
<< "must be a sequence.";
|
|
|
|
<< "must be a sequence.";
|
|
|
|
const MatrixPtr indices1 = getInputValue(1);
|
|
|
|
const MatrixPtr indices1 = getInputValue(1);
|
|
|
|
CHECK_EQ(static_cast<size_t>(indices1->getHeight()),
|
|
|
|
CHECK_EQ(
|
|
|
|
inputSeq.hasSubseq() ? inputSeq.getNumSubSequences()
|
|
|
|
indices1->getHeight(),
|
|
|
|
: inputSeq.getNumSequences())
|
|
|
|
static_cast<size_t>(inputSeq.hasSubseq() ? inputSeq.getNumSubSequences()
|
|
|
|
|
|
|
|
: inputSeq.getNumSequences()))
|
|
|
|
<< "Height of the second input should be equal to number of sequence "
|
|
|
|
<< "Height of the second input should be equal to number of sequence "
|
|
|
|
<< "in the first input.";
|
|
|
|
<< "in the first input.";
|
|
|
|
if (inputLayers_.size() == 3) {
|
|
|
|
if (inputLayers_.size() == 3) {
|
|
|
|
@ -151,7 +152,7 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
|
|
|
|
if (ends) endPos = inputSeqInfoVec_[i][j] + ends->getElement(rowIdx, k);
|
|
|
|
if (ends) endPos = inputSeqInfoVec_[i][j] + ends->getElement(rowIdx, k);
|
|
|
|
|
|
|
|
|
|
|
|
int seqLen = endPos - begPos + 1;
|
|
|
|
int seqLen = endPos - begPos + 1;
|
|
|
|
CHECK_GT(seqLen, 0U);
|
|
|
|
CHECK_GT(seqLen, 0);
|
|
|
|
for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m);
|
|
|
|
for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m);
|
|
|
|
hasSubseq
|
|
|
|
hasSubseq
|
|
|
|
? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen)
|
|
|
|
? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen)
|
|
|
|
|