@@ -14,9 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/cudnn_lstm_cache.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/cudnn_lstm_cache.h"
+#endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/operators/miopen_lstm_cache.h"
+#endif
 
 namespace paddle {
 namespace operators {
 
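Note: the two cache headers each supply the ScopedRNNBase helper that bundles
the RNN, tensor, and weight descriptors the call sites below rely on. A sketch
of the interface as implied by those call sites; the return types and comments
are assumptions inferred from usage, not copied from the headers:

    #include <cudnn.h>  // miopen.h supplies the equivalent types on HIP

    // Inferred from usage in this patch, not the header's verbatim text.
    class ScopedRNNBase {
     public:
      cudnnRNNDescriptor_t rnn_desc();        // RNN config (cell, layers, ...)
      cudnnTensorDescriptor_t *x_descs();     // per-time-step input descs
      cudnnTensorDescriptor_t *y_descs();     // per-time-step output descs
      cudnnTensorDescriptor_t init_h_desc();  // initial hidden state
      cudnnTensorDescriptor_t init_c_desc();  // initial cell state
      cudnnTensorDescriptor_t last_h_desc();  // final hidden state
      cudnnTensorDescriptor_t last_c_desc();  // final cell state
      cudnnFilterDescriptor_t weight_desc();  // packed weights and biases
    };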
@@ -54,7 +59,7 @@ int size_sum(const std::vector<const Tensor *> &weight_list) {
 }
 
 template <typename T>
-void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
+void weight_to_tensor(const platform::Place &place, gpuStream_t stream,
                       const std::vector<const Tensor *> &weight_list,
                       Tensor *weight) {
   auto weight_data = weight->data<T>();
@@ -72,7 +77,7 @@ void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
 }
 
 template <typename T>
-void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
+void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
                            std::vector<Tensor *> *weight_grad,
                            const std::vector<const Tensor *> &weight_input,
                            const Tensor *weight) {
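Note: both signatures switch from cudaStream_t to gpuStream_t, Paddle's
portability alias that resolves to the native stream type of whichever
runtime the build targets. A minimal sketch of such an alias, assuming the
usual two build flags (the exact header it lives in is not shown in this
patch):

    #ifdef PADDLE_WITH_HIP
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;   // ROCm build
    #else
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;  // CUDA build
    #endif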
@@ -92,23 +97,36 @@ void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
 }
 
 template <typename T>
+#ifdef PADDLE_WITH_HIP
+void LSTMInferece(const bool &has_seq_length, const miopenHandle_t &handle,
+#else
 void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
+#endif
                   const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                   const T *init_h_data, const T *init_c_data, const T *w_data,
                   T *out_data, T *last_h_data, T *last_c_data,
                   framework::Tensor *workspace_data,
                   const size_t &workspace_size) {
   if (!has_seq_length) {
     // for inference
     // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardInference(
+        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
+        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
+        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
+        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
+        workspace_data->data<uint8_t>(), workspace_size));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
         handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
         rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
         rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
         rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
         workspace_data->data<uint8_t>(), workspace_size));
+#endif
   } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
     // for inference
     // This interface is used when the input/output is padded.
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
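Note: the padded-sequence "Ex" entry points (cudnnRNNForwardInferenceEx and
friends) exist only in cuDNN 7.2.1+ and have no MIOpen counterpart, so the
version guard is tightened to exclude HIP builds as well. Reduced to its
skeleton, the dispatch pattern this patch applies at every call site is:

    // Unpadded path: both runtimes implement it, selected by build flag.
    #ifdef PADDLE_WITH_HIP
      // miopenRNNForwardInference(...)
    #else
      // cudnnRNNForwardInference(...)
    #endif

    // Padded path: cuDNN-only, and only from 7.2.1 on; anything else
    // falls through to the unsupported-feature branch in the full file.
    #if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
      // cudnnRNNForwardInferenceEx(...)
    #endif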
@@ -256,8 +274,17 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
                       last_c_data, &workspace_data_, workspace_size);
     } else {
       if (!has_seq_length) {
         // for train
         // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardTraining(
+            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
+            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
+            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
+            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
+            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
+            reserve_size));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
             handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
             rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
@@ -265,8 +292,9 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
             rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
             workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
             reserve_size));
+#endif
       } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
         // for train
         // This interface is used when the input/output is padded.
         PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -403,7 +431,23 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     const uint8_t *reserve_data = reserve->data<uint8_t>();
 
     if (!has_seq_length) {
       // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardData(
+          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
+          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
+          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
+          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
+          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
+          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
+          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
+
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
+          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
+          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
+          rnn.weight_desc(), weight_grad_data, workspace_data_.data<uint8_t>(),
+          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
           handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
           rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
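Note: the backward pass keeps the same two-call structure on both runtimes,
data and state gradients first, then weight gradients over the same workspace
and reserve buffers. Both APIs take the reserve space as a mutable pointer,
which is why the const-qualified reserve_data needs a const_cast in each
branch. Schematically (arguments elided):

    #ifdef PADDLE_WITH_HIP
      miopenRNNBackwardData(/* y, dy, dhy, dcy, w, hx, cx -> dx, dhx, dcx */);
      miopenRNNBackwardWeights(/* x, hx, y -> dw */);
    #else
      cudnnRNNBackwardData(/* same data flow as above */);
      cudnnRNNBackwardWeights(/* same data flow as above */);
    #endif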
@@ -418,8 +462,9 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
           rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
           workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
           weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
+#endif
     } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
       // for train
       // This interface is used when the input/output is padded.
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
@@ -452,7 +497,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN does not support double
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
+#else
 REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                         ops::CudnnLSTMGPUKernel<double>);
 REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                         ops::CudnnLSTMGPUGradKernel<double>);
+#endif
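Note: MIOpen's RNN implementation has no double-precision kernels, so the HIP
branch registers only the float instantiations; on CUDA, float and double stay
registered as before. Code that exercises fp64 therefore needs the same guard,
e.g. (TestCudnnLSTM is a hypothetical helper, not a Paddle API):

    #ifndef PADDLE_WITH_HIP
      TestCudnnLSTM<double>();  // fp64 kernels exist only on the CUDA build
    #endif
      TestCudnnLSTM<float>();   // fp32 is registered on both builds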