fix the rnn mask memory bug for out of read (#30459)

* fix the rnn mask memory bug for out of read

* update the code for the rnn
revert-31562-mean
wawltor 5 years ago committed by GitHub
parent f090066e85
commit 3d49882e2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -960,9 +960,10 @@ class RNNCPUKernel : public framework::OpKernel<T> {
if (has_seq_length) {
sequence_length = ctx.Input<Tensor>("SequenceLength");
}
if (!dropout_mask->IsInitialized()) {
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
if (dropout_mask->IsInitialized()) {
if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
}
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
// init the output and allocate the memory
output->mutable_data<T>(ctx.GetPlace());

Loading…
Cancel
Save