|
|
|
@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) {
|
|
|
|
|
size_t outDim = getSize();
|
|
|
|
|
|
|
|
|
|
size_t numSequences = input.getNumSequences();
|
|
|
|
|
auto startPositions = input.sequenceStartPositions->getVector(false);
|
|
|
|
|
const int* starts = startPositions->getData();
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(starts[numSequences], input.getBatchSize());
|
|
|
|
|
CHECK_EQ(numSequences, startPositions->getSize() - 1);
|
|
|
|
|
// by default, we assume each instance as a sequence
|
|
|
|
|
IVectorPtr seqStarts;
|
|
|
|
|
IVector::resizeOrCreate(seqStarts, input.getBatchSize() + 1, false);
|
|
|
|
|
int* startsData = seqStarts->getData();
|
|
|
|
|
for (int i = 0; i < input.getBatchSize() + 1; i++) {
|
|
|
|
|
startsData[i] = i;
|
|
|
|
|
}
|
|
|
|
|
const int* starts = startsData;
|
|
|
|
|
|
|
|
|
|
// if there is sequence, then use start positions
|
|
|
|
|
if (input.sequenceStartPositions) {
|
|
|
|
|
auto startPositions = input.sequenceStartPositions->getVector(false);
|
|
|
|
|
starts = startPositions->getData();
|
|
|
|
|
CHECK_EQ(starts[numSequences], input.getBatchSize());
|
|
|
|
|
CHECK_EQ(numSequences, startPositions->getSize() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t seqID = 0; seqID < numSequences; seqID++) {
|
|
|
|
|
size_t inNumIns = starts[seqID + 1] - starts[seqID];
|
|
|
|
|