@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <vector>
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -552,8 +553,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
     }
     auto *out_data = out->data<T>();
     auto *out_grad_data = out_grad->data<T>();
-    // maybe need check exist
-    auto *in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
+
+    // need check exist
+    T *in_grad_data = nullptr;
+    if (in_grad) {
+      in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
+    }

     bool has_seq_length = ctx.HasInput("SequenceLength");
     std::vector<int> SequenceLength;
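
(Annotation, not part of the patch: the hunk above turns the input-gradient buffer into an optional output of the backward kernel. A minimal standalone C++ sketch of that guard pattern follows; the Tensor type and mutable_data() helper are hypothetical stand-ins for the Paddle framework API, and the cuDNN call is only indicated by a comment.)

    // Sketch of the null-guarded gradient buffer used in the patched kernel.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct Tensor {  // hypothetical stand-in for the framework tensor
      std::vector<float> buf;
      float *mutable_data(std::size_t n) { buf.resize(n); return buf.data(); }
    };

    // in_grad may be nullptr when the caller did not request the input
    // gradient; allocate and use the buffer only when it is present.
    void BackwardDataSketch(Tensor *in_grad, std::size_t numel) {
      float *in_grad_data = nullptr;
      if (in_grad) {
        in_grad_data = in_grad->mutable_data(numel);
      }
      if (in_grad_data != nullptr) {
        // The real kernel calls cudnnRNNBackwardData(..., in_grad_data, ...) here.
        std::printf("input gradient computed into %p\n",
                    static_cast<void *>(in_grad_data));
      } else {
        std::printf("input gradient not requested, backward-data call skipped\n");
      }
    }

    int main() {
      Tensor dx;
      BackwardDataSketch(&dx, 16);      // gradient requested
      BackwardDataSketch(nullptr, 16);  // gradient skipped
      return 0;
    }
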
@@ -583,40 +588,52 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
     const uint8_t *reserve_data = reserve->data<uint8_t>();

     if (!has_seq_length) {
-      // This interface is used when the input/output is unpadded.
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
-          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
-          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
-          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
-          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
-          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
-          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
-          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
-
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
-          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
-          rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
-          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
-          weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
+      if (in_grad) {
+        // This interface is used when the input/output is unpadded.
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
+            handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
+            rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
+            rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
+            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
+            rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
+            rnn.init_c_desc(), init_c_grad_data,
+            workspace_data_.data<uint8_t>(), workspace_size,
+            const_cast<uint8_t *>(reserve_data), reserve_size));
+      }
+      if (!weight_grad_list.empty()) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
+            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
+            rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
+            workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
+            weight_grad_data, const_cast<uint8_t *>(reserve_data),
+            reserve_size));
+      }
     } else {
 #if CUDNN_VERSION >= 7201
       // for train
       // This interface is used when the input/output is padded.
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
-          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
-          out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
-          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
-          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
-          rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
-          rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
-          workspace_data_.data<uint8_t>(), workspace_size,
-          const_cast<uint8_t *>(reserve_data), reserve_size));
-
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
-          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
-          rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(), out->data<T>(),
-          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
-          weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
+      if (in_grad) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
+            handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data,
+            rnn.y_seq_desc(), out_grad_data, nullptr, nullptr,
+            rnn.last_h_desc(), last_h_grad_data, rnn.last_c_desc(),
+            last_c_grad_data, rnn.weight_desc(), weight_data, rnn.init_h_desc(),
+            init_h_data, rnn.init_c_desc(), init_c_data, rnn.x_seq_desc(),
+            in_grad_data, rnn.init_h_desc(), init_h_grad_data,
+            rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
+            workspace_data_.data<uint8_t>(), workspace_size,
+            const_cast<uint8_t *>(reserve_data), reserve_size));
+      }
+
+      if (!weight_grad_list.empty()) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::cudnnRNNBackwardWeightsEx(
+                handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
+                rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(),
+                out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
+                rnn.weight_desc(), weight_grad_data,
+                const_cast<uint8_t *>(reserve_data), reserve_size));
+      }
 #else
       PADDLE_THROW(platform::errors::Unavailable(
           "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "