|
|
|
@ -54,6 +54,8 @@ class ScopedRNNBase {
|
|
|
|
|
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
|
|
|
|
|
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION >= 7201
|
|
|
|
|
if (!sequence_length.empty()) {
|
|
|
|
|
x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
|
|
|
|
|
sequence_length);
|
|
|
|
@ -61,6 +63,7 @@ class ScopedRNNBase {
|
|
|
|
|
hidden_size_ * numDirections, true,
|
|
|
|
|
sequence_length);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn hx, hy, cx, cy descriptors----------
|
|
|
|
|
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
|
|
|
|
@ -96,10 +99,13 @@ class ScopedRNNBase {
|
|
|
|
|
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
|
|
|
|
|
cudnn_type));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION >= 7201
|
|
|
|
|
if (!sequence_length.empty()) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
|
|
|
|
|
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn weights_size ---------------------
|
|
|
|
|
size_t weights_size_;
|
|
|
|
@ -125,8 +131,10 @@ class ScopedRNNBase {
|
|
|
|
|
}
|
|
|
|
|
cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); }
|
|
|
|
|
cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); }
|
|
|
|
|
#if CUDNN_VERSION >= 7201
|
|
|
|
|
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
|
|
|
|
|
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
|
|
|
|
|
#endif
|
|
|
|
|
cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
|
|
|
|
|
cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
|
|
|
|
|
cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
|
|
|
|
@ -151,8 +159,10 @@ class ScopedRNNBase {
|
|
|
|
|
|
|
|
|
|
platform::ScopedTensorDescriptor x_desc_;
|
|
|
|
|
platform::ScopedTensorDescriptor y_desc_;
|
|
|
|
|
#if CUDNN_VERSION >= 7201
|
|
|
|
|
platform::ScopedRNNTensorDescriptor x_seq_desc_;
|
|
|
|
|
platform::ScopedRNNTensorDescriptor y_seq_desc_;
|
|
|
|
|
#endif
|
|
|
|
|
platform::ScopedTensorDescriptor init_h_desc_;
|
|
|
|
|
platform::ScopedTensorDescriptor init_c_desc_;
|
|
|
|
|
platform::ScopedTensorDescriptor last_h_desc_;
|
|
|
|
|