Fix the dropout setting when not initialized in rnn_op. (#28561)

test=develop
musl/fix_failed_unittests_in_musl
Guo Sheng 5 years ago committed by GitHub
parent f78211d082
commit 858ffa0c8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -89,15 +89,16 @@ class RNNDescriptors {
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
if (!is_test_ && !dropout_state->IsInitialized()) {
bool is_initialized = dropout_state->IsInitialized();
if (!is_test_ && !is_initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
}
dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(),
dropout_prob_, is_test_ ? nullptr : dropout_state,
seed_, state_size);
dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_,
is_test_ ? nullptr : dropout_state, seed_,
state_size);
// ------------------- cudnn rnn descriptors ---------------------
#if CUDNN_VERSION >= 6000

Loading…
Cancel
Save