[ROCM] update fluid operators for rocm (part7), test=develop (#31307)

Qi Li authored 4 years ago, committed by GitHub
parent db50fb6766
commit 3b9db17199

@@ -73,9 +73,11 @@ register_operators(EXCLUDES py_func_op warpctc_op dgc_op lstm_op run_program_op
 op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
-if (WITH_GPU)
+if (WITH_GPU OR WITH_ROCM)
+    if(WITH_ROCM)
+        op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
     # warpctc_op needs cudnn 7 above
-    if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+    elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
         op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
     else()
         op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
@@ -108,7 +110,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_fun
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
-if (WITH_GPU)
+if (WITH_GPU OR WITH_ROCM)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
 endif()
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
@@ -139,9 +141,12 @@ cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_t
 cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
 cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
 cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
-nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
 if (WITH_GPU)
+  nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
   nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
+elseif(WITH_ROCM)
+  hip_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
+  hip_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
 else()
   cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
 endif()

@@ -11,7 +11,7 @@
 #include "paddle/fluid/operators/bmm_op.h"
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     bmm, ops::BmmKernel<paddle::platform::CUDADeviceContext, float>,
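The widened #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) guard above is the recurring pattern in this commit: the same translation unit is compiled by both nvcc and hipcc, so CUDA-only guards are broadened instead of duplicating the file. A minimal stand-alone sketch of the idea, with hypothetical names that are not Paddle code:

// Illustration only: one source file that builds as CPU-only, CUDA, or HIP,
// with the GPU-specific part guarded once for both backends.
#include <cstdio>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const char* BackendName() { return "gpu (cuda or hip)"; }
#else
const char* BackendName() { return "cpu"; }
#endif

int main() {
  std::printf("compiled for: %s\n", BackendName());
  return 0;
}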

@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifndef PADDLE_WITH_HIP
+// HIP not support cusolver
 #include <thrust/device_vector.h>
 #include <algorithm>
 #include <vector>
@@ -164,3 +167,5 @@ REGISTER_OP_CUDA_KERNEL(
     cholesky_grad,
     ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, double>);
+#endif  // not PADDLE_WITH_HIP

@@ -25,7 +25,7 @@ namespace operators {
 using framework::Tensor;
 using platform::Transform;
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 template <typename T, typename UnaryOperation>
 __global__ void ClipCudaKernel(const T* input, T* out, int num,
                                UnaryOperation op) {
@@ -105,7 +105,7 @@ class ClipKernel : public framework::OpKernel<T> {
     const T* x_data = x->data<T>();
     int64_t numel = x->numel();
     if (platform::is_gpu_place(context.GetPlace())) {
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
       int threads = 256;
       int blocks = (numel + threads - 1) / threads;
       ClipCudaKernel<T, ClipFunctor<T>><<<

@@ -289,7 +289,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::CoalesceTensorOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::CoalesceTensorOpKernel<paddle::platform::CPUDeviceContext, double>);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL(
     coalesce_tensor,
     ops::CoalesceTensorOpKernel<paddle::platform::CUDADeviceContext,

@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifndef PADDLE_WITH_HIP
+// HIP not supported yet
 #include <algorithm>
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
@@ -480,3 +483,5 @@ REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
                         ops::CorrelationCUDAKernel<double>);
 REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel<float>,
                         ops::CorrelationCUDAGradKernel<double>);
+#endif  // not PADDLE_WITH_HIP

@@ -14,9 +14,14 @@ limitations under the License. */
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/cudnn_lstm_cache.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/cudnn_lstm_cache.h"
+#endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/operators/miopen_lstm_cache.h"
+#endif
 namespace paddle {
 namespace platform {
@@ -54,7 +59,7 @@ int size_sum(const std::vector<const Tensor *> &weight_list) {
 }
 template <typename T>
-void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
+void weight_to_tensor(const platform::Place &place, gpuStream_t stream,
                       const std::vector<const Tensor *> &weight_list,
                       Tensor *weight) {
   auto weight_data = weight->data<T>();
@@ -72,7 +77,7 @@ void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
 }
 template <typename T>
-void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
+void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
                            std::vector<Tensor *> *weight_grad,
                            const std::vector<const Tensor *> &weight_input,
                            const Tensor *weight) {
@@ -92,23 +97,36 @@ void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
 }
 template <typename T>
+#ifdef PADDLE_WITH_HIP
+void LSTMInferece(const bool &has_seq_length, const miopenHandle_t &handle,
+#else
 void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
+#endif
                   const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                   const T *init_h_data, const T *init_c_data, const T *w_data,
                   T *out_data, T *last_h_data, T *last_c_data,
                   framework::Tensor *workspace_data,
                   const size_t &workspace_size) {
   if (!has_seq_length) {
     // for inference
     // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardInference(
+        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
+        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
+        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
+        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
+        workspace_data->data<uint8_t>(), workspace_size));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
         handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
+#endif
   } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
     // for inference
     // This interface is used when the input/output is padded.
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
@@ -256,8 +274,17 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
                    last_c_data, &workspace_data_, workspace_size);
     } else {
       if (!has_seq_length) {
        // for train
        // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNForwardTraining(
+            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
+            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
+            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
+            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
+            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
+            reserve_size));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
             handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
@@ -265,8 +292,9 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
+#endif
       } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
         // for train
         // This interface is used when the input/output is padded.
         PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -403,7 +431,23 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     const uint8_t *reserve_data = reserve->data<uint8_t>();
     if (!has_seq_length) {
       // This interface is used when the input/output is unpadded.
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardData(
+          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
+          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
+          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
+          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
+          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
+          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
+          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
+          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
+          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
+          rnn.weight_desc(), weight_grad_data, workspace_data_.data<uint8_t>(),
+          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
           handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
@@ -418,8 +462,9 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
          weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
+#endif
     } else {
-#if CUDNN_VERSION >= 7201
+#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
       // for train
       // This interface is used when the input/output is padded.
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
@@ -452,7 +497,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN do not support double
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
+#else
 REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                         ops::CudnnLSTMGPUKernel<double>);
 REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                         ops::CudnnLSTMGPUGradKernel<double>);
+#endif
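The LSTM kernel above keeps a single code path by switching the handle type (cudnnHandle_t vs miopenHandle_t) and the dynload calls at preprocessor time, and by taking gpuStream_t for the stream argument. A compressed, stand-alone sketch of that compile-time selection; the alias names here (RnnHandle, GpuStream, SyncStream) are illustrative, not what Paddle defines, and the sketch assumes the cuDNN or MIOpen development headers are installed:

// Illustrative only: pick backend-specific types once, then write the rest of
// the code against the aliases.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
using RnnHandle = miopenHandle_t;
using GpuStream = hipStream_t;
#else
#include <cuda_runtime.h>
#include <cudnn.h>
using RnnHandle = cudnnHandle_t;
using GpuStream = cudaStream_t;
#endif

// Written once against the aliases; the backend difference stays confined to
// small #ifdef blocks like the one below.
inline void SyncStream(GpuStream stream) {
#ifdef PADDLE_WITH_HIP
  hipStreamSynchronize(stream);
#else
  cudaStreamSynchronize(stream);
#endif
}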

@@ -16,7 +16,13 @@ limitations under the License. */
 #include <thrust/device_vector.h>
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
-#include "cub/cub.cuh"
+#ifdef __NVCC__
+#include <cub/cub.cuh>
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/operators/cum_op.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"

@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/operators/data_norm_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -174,7 +174,7 @@ class DataNormGradKernel<platform::CUDADeviceContext, T>
         d_batch_sum, d_batch_square_sum);
     if (need_sync_stats) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
       auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
           reinterpret_cast<const void *>(d_batch_size),
@@ -188,7 +188,11 @@ class DataNormGradKernel<platform::CUDADeviceContext, T>
          reinterpret_cast<const void *>(d_batch_square_sum),
          reinterpret_cast<void *>(d_batch_square_sum), C,
          platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream));
+#ifdef PADDLE_WITH_RCCL
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
 #else
       PADDLE_THROW(platform::errors::PreconditionNotMet(
           "PaddlePaddle should compile with GPU, and need_sync_stats connot be "

@@ -100,7 +100,7 @@ class DiagEmbedKernel : public framework::OpKernel<T> {
       strides.push_back(stride[dim1_] + stride[dim2_]);
     const auto dims = vectorize(input->dims());
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     thrust::device_vector<int64_t> dims_vec(dims);
     const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data());
     thrust::device_vector<int64_t> strides_vec(strides);
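__NVCC__ is only defined by nvcc, so without the added __HIPCC__ check the thrust branch above would be skipped entirely when hipcc compiles the header; rocThrust provides the same thrust:: API under ROCm. A small stand-alone sketch of that guard (illustrative, not Paddle code):

#include <vector>
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#endif

// Sums on the device when a GPU compiler is driving, on the host otherwise.
inline long long SumAnywhere(const std::vector<int>& host) {
#if defined(__NVCC__) || defined(__HIPCC__)
  thrust::device_vector<int> dev(host.begin(), host.end());
  return thrust::reduce(dev.begin(), dev.end(), 0LL);
#else
  long long total = 0;
  for (int v : host) total += v;
  return total;
#endif
}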

@@ -45,7 +45,7 @@ struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
                   const Tensor* tensor_dout, Tensor* tensor_dx,
                   Tensor* tensor_dy,
                   const paddle::framework::ExecutionContext& ctx) {
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     if (1 == tensor_dout->dims().size()) {
       auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
@@ -249,7 +249,7 @@ class DotKernel : public framework::OpKernel<T> {
     auto* tensor_out = ctx.Output<Tensor>("Out");
     tensor_out->mutable_data<T>(ctx.GetPlace());
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     if (1 == tensor_out->dims().size()) {
       auto out = framework::EigenScalar<T>::From(*tensor_out);
       auto x = framework::EigenVector<T>::Flatten(*tensor_x);

@@ -11,8 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <curand_kernel.h>
+#include "paddle/fluid/platform/dynload/curand.h"
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hiprand_kernel.h>
+#include "paddle/fluid/platform/dynload/hiprand.h"
+#endif
 #include <thrust/device_ptr.h>
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
@@ -21,7 +30,6 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/dropout_op.h"
-#include "paddle/fluid/platform/dynload/curand.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -32,15 +40,24 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
                                 const float dropout_prob, const T* src,
                                 MaskType* mask_data, T* dst,
                                 bool is_upscale_in_train, uint64_t increment) {
-  curandStatePhilox4_32_10_t state;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
+#ifdef PADDLE_WITH_HIP
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx, increment, &state);
+#else
+  curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);
+#endif
   MaskType mask;
   T dest;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
+#ifdef PADDLE_WITH_HIP
+    if (hiprand_uniform(&state) < dropout_prob) {
+#else
     if (curand_uniform(&state) < dropout_prob) {
+#endif
       mask = 0;
       dest = 0;
     } else {
@@ -62,9 +79,15 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const T* src, MaskType* mask_data,
                                           T* dst, bool is_upscale_in_train,
                                           uint64_t increment) {
+#ifdef PADDLE_WITH_HIP
+  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx, increment, &state);
+#else
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);
+#endif
   MaskType mask;
   T dest;
@@ -75,7 +98,11 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
     T src_vec[VecSize];
     LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
     *value = *reinterpret_cast<const LoadT*>(&src[i]);
+#ifdef PADDLE_WITH_HIP
+    float4 rand = hiprand_uniform4(&state);
+#else
     float4 rand = curand_uniform4(&state);
+#endif
     T dest_vec[VecSize];
     MaskType mask_vec[VecSize];
@@ -131,10 +158,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     auto* x_data = x->data<T>();
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     if (dropout_prob == 1.0f) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(
           cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
          mask_data, 0, x_numel * sizeof(*mask_data), stream));
+#endif
       return;
     }
@@ -180,6 +214,20 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       increment = offset;
     }
+#ifdef __HIPCC__
+    if (vec_size == 4 && size % 4 == 0) {
+      hipLaunchKernelGGL(
+          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
+          config.block_per_grid, config.thread_per_block, 0, stream, size,
+          seed_data, dropout_prob, x_data, mask_data, y_data,
+          upscale_in_train, increment);
+    } else {
+      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
+                         config.block_per_grid, config.thread_per_block, 0,
+                         stream, size, seed_data, dropout_prob, x_data,
+                         mask_data, y_data, upscale_in_train, increment);
+    }
+#else
     if (vec_size == 4 && size % 4 == 0) {
       VectorizedRandomGenerator<
           T, uint8_t,
@@ -192,7 +240,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
     }
+#endif
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
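The dropout kernel changes above boil down to two substitutions: the Philox generator comes from hiprand instead of curand, and the launch goes through hipLaunchKernelGGL instead of the <<<...>>> syntax. The toy kernel below condenses both patterns; it is a stand-in written for this note, not Paddle's RandomGenerator, and it assumes compilation by nvcc or hipcc:

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#else
#include <cuda_runtime.h>
#include <curand_kernel.h>
#endif
#include <cstdint>

// Grid-stride masked dropout: each thread draws one uniform per element.
template <typename T>
__global__ void ToyDropout(const T* src, T* dst, uint8_t* mask, size_t n,
                           float p, uint64_t seed, uint64_t increment) {
  size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif
  for (; idx < n; idx += static_cast<size_t>(blockDim.x) * gridDim.x) {
#ifdef PADDLE_WITH_HIP
    bool drop = hiprand_uniform(&state) < p;
#else
    bool drop = curand_uniform(&state) < p;
#endif
    mask[idx] = drop ? 0 : 1;
    dst[idx] = drop ? static_cast<T>(0) : src[idx];
  }
}

// The launch syntax is the other backend difference.
template <typename T>
void LaunchToyDropout(const T* src, T* dst, uint8_t* mask, size_t n, float p,
                      uint64_t seed, uint64_t increment) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL(HIP_KERNEL_NAME(ToyDropout<T>), dim3(blocks),
                     dim3(threads), 0, 0, src, dst, mask, n, p, seed,
                     increment);
#else
  ToyDropout<T><<<blocks, threads>>>(src, dst, mask, n, p, seed, increment);
#endif
}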

@@ -42,7 +42,7 @@ inline int VectorizedSize(const T* pointer) {
   return 1;
 }
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 template <typename T, typename MaskType, int VecSize>
 __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                       const T factor, const int64_t size,
@@ -186,7 +186,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
       int vec_size = VectorizedSize<T>(grad_y->data<T>());
       if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
           size % 4 == 0) {
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
         auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
         auto stream = context.cuda_device_context().stream();
         platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(

@@ -162,7 +162,11 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
     int grid = cout;
     int max_threads = 1024;
+#ifdef PADDLE_WITH_HIP
+    hipMemset(out_abs_max, 0, sizeof(T) * cout);
+#else
     cudaMemset(out_abs_max, 0, sizeof(T) * cout);
+#endif
     for (int i = 0; i < cin / max_threads; i++) {
       int block = max_threads;
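Runtime calls such as cudaMemset have direct hip* counterparts with identical signatures, which is why the change above is a one-line #ifdef. A tiny wrapper in the same spirit; the GpuZero name is made up for this example:

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

// Zero `count` elements of type T in device memory on either backend.
template <typename T>
inline void GpuZero(T* device_ptr, size_t count) {
#ifdef PADDLE_WITH_HIP
  hipMemset(device_ptr, 0, sizeof(T) * count);
#else
  cudaMemset(device_ptr, 0, sizeof(T) * count);
#endif
}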

@@ -65,7 +65,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
       functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
               out, static_cast<T>(value));
     }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (!cpu_place) {
       math::SetConstant<platform::CUDADeviceContext, T> functor;
       out->mutable_data(ctx.GetPlace(), data_type);

@@ -121,7 +121,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
       functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
               tensor, static_cast<T>(value));
     } else if (actual_place == 1) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       tensor->mutable_data(ctx.GetPlace(), data_type);
       math::SetConstant<platform::CUDADeviceContext, T> functor;
       functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
@@ -131,7 +131,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
           "PaddlePaddle should compile with GPU."));
 #endif
     } else if (actual_place == 2) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       tensor->mutable_data(platform::CUDAPinnedPlace(), data_type);
       math::SetConstant<platform::CPUDeviceContext, T> functor;
       functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),

@@ -31,7 +31,7 @@ namespace operators {
 using Tensor = framework::Tensor;
 using SelectedRows = framework::SelectedRows;
 using LoDTensor = framework::LoDTensor;
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T>
 using Vector = framework::Vector<T>;
 #else

@@ -54,7 +54,8 @@ struct GeluFunctor {
       }
     } else {
 #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \
+    !defined(PADDLE_WITH_HIP)
       auto x_data = x.data();
       auto out_data = out.data();
       int n = std::min(x.size(), out.size());
@@ -121,7 +122,8 @@ struct GeluGradFunctor {
       }
     } else {
 #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \
+    !defined(PADDLE_WITH_HIP)
       auto x_data = x.data();
       auto dx_data = dx.data();
       auto dout_data = dout.data();

@@ -107,7 +107,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
                                ops::GetTensorFromSelectedRowsKernel, int64_t,
                                ops::GetTensorFromSelectedRowsKernel);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
                                 ops::GetTensorFromSelectedRowsKernel, double,
                                 ops::GetTensorFromSelectedRowsKernel, int,
