|
|
|
@ -177,11 +177,20 @@ struct CudnnRNNCache {
|
|
|
|
|
seed_));
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION >= 6000
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
|
|
|
|
|
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
|
|
|
|
|
CUDNN_LINEAR_INPUT,
|
|
|
|
|
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
|
|
|
|
|
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
|
|
|
|
|
rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
|
|
|
|
|
CUDNN_LINEAR_INPUT,
|
|
|
|
|
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
|
|
|
|
|
CUDNN_DATA_FLOAT));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
|
|
|
|
|