Merge pull request #1187 from qingqing01/gru_bug

Fix a bug in GatedRecurrentLayer that occurs only during prediction or testing.
avx_docs
Tao Luo 8 years ago committed by GitHub
commit d7ee239f7a

@ -314,13 +314,13 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
batchValue_->resizeOrCreate(*output_.value); batchValue_->resizeOrCreate(*output_.value);
batchValue_->copy(*inputValue, *gate_.value, /* seq2batch */ true); batchValue_->copy(*inputValue, *gate_.value, /* seq2batch */ true);
if (bias_ && bias_->getWGrad()) { if (bias_) {
gate_.value->addBias(*(bias_->getW()), 1); gate_.value->addBias(*(bias_->getW()), 1);
} }
{ {
int numBatch = batchValue_->getNumBatch(); int numBatch = batchValue_->getNumBatch();
int batchSize = 0; int curBatchSize = 0;
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
for (int n = 0; n < numBatch; n++) { for (int n = 0; n < numBatch; n++) {
MatrixPtr outputValueTmp = batchValue_->getBatchValue(n); MatrixPtr outputValueTmp = batchValue_->getBatchValue(n);
@ -330,16 +330,17 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
gruValue.resetOutputValue = gruValue.resetOutputValue =
(batchValue_->getBatchValue(*resetOutput_.value, n))->getData(); (batchValue_->getBatchValue(*resetOutput_.value, n))->getData();
batchSize = outputValueTmp->getHeight(); curBatchSize = outputValueTmp->getHeight();
gruValue.prevOutValue = gruValue.prevOutValue =
(n == 0 ? nullptr (n == 0
: (batchValue_->getBatchValue(n - 1, batchSize))->getData()); ? nullptr
: (batchValue_->getBatchValue(n - 1, curBatchSize))->getData());
{ {
if (useGpu_) { if (useGpu_) {
GruCompute::forward<1>(gruValue, getSize(), batchSize); GruCompute::forward<1>(gruValue, getSize(), curBatchSize);
} else { } else {
GruCompute::forward<0>(gruValue, getSize(), batchSize); GruCompute::forward<0>(gruValue, getSize(), curBatchSize);
} }
} }
} }

Loading…
Cancel
Save