|
|
|
@ -89,15 +89,16 @@ class RNNDescriptors {
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn dropout descriptors ---------------------
|
|
|
|
|
size_t state_size;
|
|
|
|
|
if (!is_test_ && !dropout_state->IsInitialized()) {
|
|
|
|
|
bool is_initialized = dropout_state->IsInitialized();
|
|
|
|
|
if (!is_test_ && !is_initialized) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
|
|
|
|
|
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
|
|
|
|
|
place);
|
|
|
|
|
}
|
|
|
|
|
dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(),
|
|
|
|
|
dropout_prob_, is_test_ ? nullptr : dropout_state,
|
|
|
|
|
seed_, state_size);
|
|
|
|
|
dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_,
|
|
|
|
|
is_test_ ? nullptr : dropout_state, seed_,
|
|
|
|
|
state_size);
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn rnn descriptors ---------------------
|
|
|
|
|
#if CUDNN_VERSION >= 6000
|
|
|
|
|