|
|
|
@ -130,6 +130,8 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
|
|
|
|
|
CHECK(starts || ends) << "At least one of the start or end indices "
|
|
|
|
|
<< "should be given.";
|
|
|
|
|
|
|
|
|
|
bool hasSubseq = getInput(0).hasSubseq();
|
|
|
|
|
|
|
|
|
|
outSeqStartPos_.resize(1, 0);
|
|
|
|
|
outSubSeqStartPos_.resize(1, 0);
|
|
|
|
|
selectedRows_.clear();
|
|
|
|
@ -151,14 +153,13 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
|
|
|
|
|
int seqLen = endPos - begPos + 1;
|
|
|
|
|
CHECK_GT(seqLen, 0U);
|
|
|
|
|
for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m);
|
|
|
|
|
inputSeqInfoVec_.size() > 1
|
|
|
|
|
hasSubseq
|
|
|
|
|
? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen)
|
|
|
|
|
: outSeqStartPos_.push_back(outSeqStartPos_.back() + seqLen);
|
|
|
|
|
}
|
|
|
|
|
rowIdx++;
|
|
|
|
|
}
|
|
|
|
|
if (inputSeqInfoVec_.size() > 1)
|
|
|
|
|
outSeqStartPos_.push_back(outSubSeqStartPos_.back());
|
|
|
|
|
if (hasSubseq) outSeqStartPos_.push_back(outSubSeqStartPos_.back());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (useGpu_) {
|
|
|
|
@ -175,7 +176,7 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
|
|
|
|
|
output_.sequenceStartPositions->copyFrom(
|
|
|
|
|
outSeqStartPos_.data(), outSeqStartPos_.size(), false);
|
|
|
|
|
|
|
|
|
|
if (inputSeqInfoVec_.size() > 1) {
|
|
|
|
|
if (hasSubseq) {
|
|
|
|
|
ICpuGpuVector::resizeOrCreate(
|
|
|
|
|
output_.subSequenceStartPositions, outSubSeqStartPos_.size(), false);
|
|
|
|
|
output_.subSequenceStartPositions->copyFrom(
|
|
|
|
@ -204,10 +205,11 @@ void SequenceSliceLayer::forward(PassType passType) {
|
|
|
|
|
copySliceIdsToCpu();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// calculate the selected row indices in a batch,
|
|
|
|
|
// and build the output sequence information.
|
|
|
|
|
calSelectedRows(startIdsOnCpu_ ? startIdsOnCpu_ : nullptr,
|
|
|
|
|
endIdsOnCpu_ ? endIdsOnCpu_ : nullptr);
|
|
|
|
|
/*
|
|
|
|
|
* calculate the selected row indices in a batch, and build the output
|
|
|
|
|
* sequence information.
|
|
|
|
|
*/
|
|
|
|
|
calSelectedRows(startIdsOnCpu_, endIdsOnCpu_);
|
|
|
|
|
|
|
|
|
|
resetOutput(selectedRows_.size(), getSize());
|
|
|
|
|
|
|
|
|
|