From 946dbdae8c411b4235abcb9d38931befc8fb3e4c Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Wed, 3 Mar 2021 11:21:29 +0800
Subject: [PATCH] [ROCM] update fluid operators for rocm (part6), test=develop
 (#31301)

---
 paddle/fluid/operators/activation_cudnn.cu.cc |   4 +
 .../fluid/operators/activation_cudnn_op.cu.cc |  80 +++++++-
 paddle/fluid/operators/activation_op.cc       |   3 -
 paddle/fluid/operators/affine_channel_op.cu   |   8 +
 .../operators/affine_grid_cudnn_op.cu.cc      |   5 +
 paddle/fluid/operators/affine_grid_op.cc      |   7 +-
 paddle/fluid/operators/allclose_op.cu         |   5 +-
 .../fluid/operators/arg_min_max_op_base.cu.h  |  10 +-
 paddle/fluid/operators/argsort_op.cu          |  18 +-
 paddle/fluid/operators/batch_fc_op.cu         |   5 +-
 paddle/fluid/operators/batch_norm_op.cu       | 179 +++++++++++++++++-
 paddle/fluid/operators/bce_loss_op.cu         |   1 -
 .../operators/math/sequence_padding_test.cc   |   2 +-
 .../operators/math/sequence_pooling_test.cc   |   2 +-
 paddle/fluid/operators/math/sequence_scale.cu |   8 +
 paddle/fluid/operators/math/softmax.cu        |  39 +++-
 paddle/fluid/operators/math/softmax.h         |   2 +-
 paddle/fluid/operators/pool_op.h              |   2 +-
 .../paddle/fluid/tests/unittests/op_test.py   |   6 +-
 19 files changed, 350 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/operators/activation_cudnn.cu.cc b/paddle/fluid/operators/activation_cudnn.cu.cc
index 7f8ecc1df0..38499783eb 100644
--- a/paddle/fluid/operators/activation_cudnn.cu.cc
+++ b/paddle/fluid/operators/activation_cudnn.cu.cc
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc
index 26ad09cc26..b197d3511f 100644
--- a/paddle/fluid/operators/activation_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/activation_cudnn_op.cu.cc
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif
 
 namespace paddle {
 namespace platform {
@@ -29,35 +33,71 @@ using platform::ActivationDescriptor;
 using platform::TensorDescriptor;
 using platform::CUDADeviceContext;
 
+#ifdef PADDLE_WITH_HIP
+#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
+#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
+#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
+#else
+#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
+#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
+#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
+#endif
+
 template <typename T>
 struct CudnnActivationFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
+                         const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                          const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
  void operator()(const Tensor& x, Tensor* out) {
     ActivationDescriptor act_desc;
     act_desc.set(mode_, coef_);
     TensorDescriptor x_desc, out_desc;
     x_desc.set(x);
     out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), out_desc.desc(),
+        out->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), out_desc.desc(),
         out->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };
 
 template <typename T>
 struct CudnnActivationGradFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
+                             const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                              const cudnnActivationMode_t& m)
       : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
   void operator()(const Tensor& x, const Tensor& out, const Tensor dout,
                   Tensor* dx) {
     ActivationDescriptor act_desc;
@@ -67,27 +107,40 @@ struct CudnnActivationGradFunctor {
     out_desc.set(out);
     dout_desc.set(dout);
     dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationBackward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
+        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
+        dx->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
         dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
         dx->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };
 
 template <typename T>
 struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
 };
 template <typename T>
 struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -95,13 +148,13 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
   explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {}
 };
 template <typename T>
 struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
-  }
+      : CudnnActivationGradFunctor<T>(ctx, 6.0,
+                                      GPUDNN_ACTIVATION_CLIPPED_RELU) {}
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -109,12 +162,12 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
 };
 template <typename T>
 struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -122,12 +175,12 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
 };
 template <typename T>
 struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -183,6 +236,14 @@ namespace ops = paddle::operators;
   __macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor);         \
   __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)
 
+#ifdef PADDLE_WITH_HIP
+#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
+                     ops::CudnnActivationKernel<ops::functor<float>>);    \
+  REGISTER_OP_KERNEL(                                                     \
+      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
+      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
+#else
 #define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                      ops::CudnnActivationKernel<ops::functor<float>>,     \
@@ -191,5 +252,6 @@ namespace ops = paddle::operators;
       act_type##_grad, CUDNN, plat::CUDAPlace,                            \
       ops::CudnnActivationGradKernel<ops::grad_functor<float>>,           \
       ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
+#endif
 
 FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 785d6daaec..94f2eb3672 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -24,9 +24,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
 #include "paddle/fluid/platform/port.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cudnn_helper.h"
-#endif
 
 DECLARE_bool(use_mkldnn);
 
diff --git a/paddle/fluid/operators/affine_channel_op.cu b/paddle/fluid/operators/affine_channel_op.cu
index 5e59807121..cddc288c24 100644
--- a/paddle/fluid/operators/affine_channel_op.cu
+++ b/paddle/fluid/operators/affine_channel_op.cu
@@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
diff --git a/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc b/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
index c09f71f46c..b8ce52387b 100644
--- a/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifndef PADDLE_WITH_HIP
+// HIP not support cudnnSpatialTfGridGeneratorForward
+
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 
@@ -121,3 +124,5 @@ REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNAffineGridGradOpKernel<float>,
                    paddle::operators::CUDNNAffineGridGradOpKernel<double>);
+
+#endif  // not PADDLE_WITH_HIP
diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc
index 675baa6768..7be9bced13 100644
--- a/paddle/fluid/operators/affine_grid_op.cc
+++ b/paddle/fluid/operators/affine_grid_op.cc
@@ -21,6 +21,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -109,7 +112,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
@@ -226,7 +229,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
diff --git a/paddle/fluid/operators/allclose_op.cu b/paddle/fluid/operators/allclose_op.cu
index f98fe75cd6..173e24b2f1 100644
--- a/paddle/fluid/operators/allclose_op.cu
+++ b/paddle/fluid/operators/allclose_op.cu
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/allclose_op.h"
@@ -67,7 +66,11 @@ struct AllcloseFunctor {
     int block = 1024;
     int grid = (block - 1 + num) / block;
     grid = (grid > block) ? block : grid;
+#ifdef PADDLE_WITH_HIP
+    hipMemset(out_data, true, sizeof(bool));
+#else
     cudaMemset(out_data, true, sizeof(bool));
+#endif
     AllcloseCUDAKernel<<>>(
         in_data, other_data, rtol, atol, equal_nan, num, out_data);
   }
diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h
index 3e549428b0..b19ba1e159 100644
--- a/paddle/fluid/operators/arg_min_max_op_base.cu.h
+++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h
@@ -14,9 +14,15 @@ limitations under the License. */
 
 #pragma once
 
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 
-#include <cub/cub.cuh>
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include 
 #include 
 #include 
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index 7fc2a92b7d..f50d5e619e 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -16,13 +16,28 @@ limitations under the License. */
 #include 
 #include 
 #include 
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
+#ifdef __HIPCC__
+namespace rocprim {
+namespace detail {
+template <>
+struct radix_key_codec_base<paddle::platform::float16>
+    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
+}  // namespace detail
+}  // namespace rocprim
+#else
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -30,6 +45,7 @@
 struct NumericTraits<paddle::platform::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                  paddle::platform::float16> {};
 }  // namespace cub
+#endif
 
 namespace paddle {
 namespace operators {
@@ -139,7 +155,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                               cub::CountingInputIterator<int>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  cudaError_t err;
+  gpuError_t err;
   if (descending) {
     err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
         nullptr, temp_storage_bytes, inp, sorted_out_ptr,
diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu
index 9a39306cca..b686c766e0 100644
--- a/paddle/fluid/operators/batch_fc_op.cu
+++ b/paddle/fluid/operators/batch_fc_op.cu
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include 
 #include 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/batch_fc_op.h"
@@ -42,7 +41,7 @@ __global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
 }
 
 template <typename T>
-void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
+void add_bias(gpuStream_t stream, T* data, int slot_pairs_num, int ins_num,
               int out_dim, const T* bias) {
   add_bias_kernel<<>>(data, slot_pairs_num,
@@ -65,7 +64,7 @@ __global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
 }
 
 template <typename T>
-void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
+void add_bias_grad(gpuStream_t stream, const T* dout_data, int slot_pairs_num,
                    int ins_num, int out_dim, T* db_data) {
   add_bias_grad_kernel<<>>(dout_data, slot_pairs_num, ins_num,
diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu
index ae9cf2838b..444c24b826 100644
--- a/paddle/fluid/operators/batch_norm_op.cu
+++ b/paddle/fluid/operators/batch_norm_op.cu
@@ -16,12 +16,17 @@ limitations under the License. */
 #include 
 #include 
 #include 
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/norm_utils.cu.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
 
 DECLARE_bool(cudnn_batchnorm_spatial_persistent);
@@ -73,6 +78,11 @@ class BatchNormKernel
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
     auto dtype = platform::CudnnDataType<T>::type;
+
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         test_mode ||
         (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
@@ -81,6 +91,7 @@ class BatchNormKernel
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
             : DataLayout::kNCHW;
+#endif
 
     Tensor transformed_x(x->type());
     Tensor transformed_y(y->type());
@@ -98,7 +109,17 @@ class BatchNormKernel
       transformed_y.ShareDataWith(*y);
     }
 
-    // ------------------- cudnn descriptors ---------------------
+// ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -107,6 +128,7 @@ class BatchNormKernel
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif
 
     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
@@ -114,7 +136,10 @@ class BatchNormKernel
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -134,6 +159,17 @@ class BatchNormKernel
       dims = {N, C, H, W, D};
       strides = {H * W * D * C, 1, W * D * C, D * C, C};
     }
+
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+        const_cast<int *>(strides.data())));
+    // Note: PERSISTENT not implemented for inference
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(
+            bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
@@ -142,6 +178,7 @@ class BatchNormKernel
         platform::dynload::cudnnDeriveBNTensorDescriptor(
             bn_param_desc_, data_desc_,
             test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
+#endif
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -188,6 +225,30 @@ class BatchNormKernel
                 "variance is [%d], the dimensions of variance is [%s].",
                 C, est_var->dims()[0], est_var->dims()));
 
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::miopenBatchNormalizationForwardInference(
+                handle, miopenBNSpatial,
+                const_cast<void *>(
+                    static_cast<const void *>(CudnnDataType<T>::kOne())),
+                const_cast<void *>(
+                    static_cast<const void *>(CudnnDataType<T>::kZero())),
+                data_desc_,
+                static_cast<const void *>(transformed_x.template data<T>()),
+                data_desc_,
+                static_cast<void *>(
+                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                bn_param_desc_,
+                const_cast<void *>(static_cast<const void *>(
+                    scale->template data<BatchNormParamType<T>>())),
+                const_cast<void *>(static_cast<const void *>(
+                    bias->template data<BatchNormParamType<T>>())),
+                const_cast<void *>(static_cast<const void *>(
+                    est_mean->template data<BatchNormParamType<T>>())),
+                const_cast<void *>(static_cast<const void *>(
+                    est_var->template data<BatchNormParamType<T>>())),
+                epsilon));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationForwardInference(
                 handle,
@@ -200,6 +261,7 @@ class BatchNormKernel
                 bias->template data<BatchNormParamType<T>>(),
                 est_mean->template data<BatchNormParamType<T>>(),
                 est_var->template data<BatchNormParamType<T>>(), epsilon));
+#endif
       } else {
         // if MomentumTensor is set, use MomentumTensor value, momentum
         // is only used in this training branch
@@ -302,6 +364,36 @@ class BatchNormKernel
                 reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
         if (!called) {
+#ifdef PADDLE_WITH_HIP
+          PADDLE_ENFORCE_CUDA_SUCCESS(
+              platform::dynload::miopenBatchNormalizationForwardTraining(
+                  handle, mode_, const_cast<void *>(static_cast<const void *>(
+                                     CudnnDataType<T>::kOne())),
+                  const_cast<void *>(
+                      static_cast<const void *>(CudnnDataType<T>::kZero())),
+                  data_desc_,
+                  static_cast<const void *>(transformed_x.template data<T>()),
+                  data_desc_,
+                  static_cast<void *>(
+                      transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                  bn_param_desc_,
+                  const_cast<void *>(static_cast<const void *>(
+                      scale->template data<BatchNormParamType<T>>())),
+                  const_cast<void *>(static_cast<const void *>(
+                      bias->template data<BatchNormParamType<T>>())),
+                  this_factor,
+                  static_cast<void *>(
+                      mean_out->template mutable_data<BatchNormParamType<T>>(
+                          ctx.GetPlace())),
+                  static_cast<void *>(variance_out->template mutable_data<
+                                      BatchNormParamType<T>>(ctx.GetPlace())),
+                  epsilon,
+                  static_cast<void *>(
+                      saved_mean->template mutable_data<BatchNormParamType<T>>(
+                          ctx.GetPlace())),
+                  static_cast<void *>(saved_variance->template mutable_data<
+                                      BatchNormParamType<T>>(ctx.GetPlace()))));
+#else
           PADDLE_ENFORCE_CUDA_SUCCESS(
               platform::dynload::cudnnBatchNormalizationForwardTraining(
                   handle, mode_, CudnnDataType<T>::kOne(),
@@ -319,6 +411,7 @@ class BatchNormKernel
                       ctx.GetPlace()),
                   saved_variance->template mutable_data<BatchNormParamType<T>>(
                       ctx.GetPlace())));
+#endif
         }
       }
     }
@@ -329,11 +422,19 @@ class BatchNormKernel
       TransToChannelLast<platform::CUDADeviceContext, T>(
           ctx, &transformed_y, y);
     }
+#ifdef PADDLE_WITH_HIP
+    // clean when exit.
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
     // clean when exit.
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
   }
 };
 
@@ -416,7 +517,7 @@ class InplaceHelper {
                   const BatchNormParamType<T> *mean,
                   const BatchNormParamType<T> *variance, double epsilon, int C,
                   int M, const int num, const T *y, int grid2, const int block,
-                  const cudaStream_t &stream) {
+                  const gpuStream_t &stream) {
     PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                 "X and Y should be inplaced in inplace mode"));
     KeBNRestoreData<<>>(
@@ -566,6 +667,10 @@ class BatchNormGradKernel
     auto dtype = platform::CudnnDataType<T>::type;
     const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
 
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
         reserve_space != nullptr;
@@ -573,6 +678,7 @@ class BatchNormGradKernel
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
            : DataLayout::kNCHW;
+#endif
 
     Tensor transformed_x(x->type());
     Tensor transformed_d_y(d_y->type());
@@ -626,7 +732,17 @@ class BatchNormGradKernel
       return;
     }
 
-    // ------------------- cudnn descriptors ---------------------
+// ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -635,13 +751,16 @@ class BatchNormGradKernel
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif
 
     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
                  << "CUDNN_BN_MIN_EPSILON. Setting it to "
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -651,12 +770,22 @@ class BatchNormGradKernel
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif  // CUDNN_VERSION_MIN(7, 0, 1)
 
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+        const_cast<int *>(strides.data())));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
+                                                          data_desc_, mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                          data_desc_, mode_));
+#endif
 
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
@@ -741,6 +870,22 @@ class BatchNormGradKernel
                 /*reserveSpaceSizeInBytes=*/reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
       if (!called) {
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::miopenBatchNormalizationBackward(
+                dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+                CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+                CudnnDataType<T>::kZero(), data_desc_,
+                transformed_x.template data<T>(), data_desc_,
+                transformed_d_y.template data<T>(), data_desc_,
+                transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
+                bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+                d_scale->template mutable_data<BatchNormParamType<T>>(
+                    ctx.GetPlace()),
+                d_bias->template mutable_data<BatchNormParamType<T>>(
+                    ctx.GetPlace()),
+                epsilon, saved_mean_data, saved_var_data));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationBackward(
                 dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
@@ -755,6 +900,7 @@ class BatchNormGradKernel
                 d_bias->template mutable_data<BatchNormParamType<T>>(
                     ctx.GetPlace()),
                 epsilon, saved_mean_data, saved_var_data));
+#endif
       }
 
      if (data_layout == DataLayout::kNHWC &&
@@ -784,11 +930,19 @@ class BatchNormGradKernel
         }
       }
 
+#ifdef PADDLE_WITH_HIP
+      // clean when exit.
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
       // clean when exit.
      PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
     } else {
       const auto *running_mean = ctx.Input<Tensor>("Mean");
       const auto *running_var = ctx.Input<Tensor>("Variance");
@@ -886,6 +1040,18 @@ class BatchNormDoubleGradKernel
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN do not support double
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad_grad,
+    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormKernel<plat::CUDADeviceContext, double>,
@@ -898,3 +1064,4 @@ REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad_grad,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
+#endif
diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu
index 1a967c5738..99153101fc 100644
--- a/paddle/fluid/operators/bce_loss_op.cu
+++ b/paddle/fluid/operators/bce_loss_op.cu
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include 
-#include "cub/cub.cuh"
 #include "paddle/fluid/operators/bce_loss_op.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc
index 1f7e9f9ae0..ea31b10c55 100644
--- a/paddle/fluid/operators/math/sequence_padding_test.cc
+++ b/paddle/fluid/operators/math/sequence_padding_test.cc
@@ -105,7 +105,7 @@ TEST(Seq2BatchPadding, CPU) {
                                       128);
 }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePadding, CUDA) {
   auto place = paddle::platform::CUDAPlace(0);
   auto *context = static_cast(
diff --git a/paddle/fluid/operators/math/sequence_pooling_test.cc b/paddle/fluid/operators/math/sequence_pooling_test.cc
index 4ece42ab80..775d8029bf 100644
--- a/paddle/fluid/operators/math/sequence_pooling_test.cc
+++ b/paddle/fluid/operators/math/sequence_pooling_test.cc
@@ -123,7 +123,7 @@ TEST(SequencePoolingGrad, CPU_SUM) {
                            lod2, 128);
 }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePoolingGrad, CUDA_SUM) {
   auto place = paddle::platform::CUDAPlace(0);
   auto *context = static_cast(
diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu
index 4a952afe15..5578f1f013 100644
--- a/paddle/fluid/operators/math/sequence_scale.cu
+++ b/paddle/fluid/operators/math/sequence_scale.cu
@@ -44,10 +44,18 @@ class ScaleLoDTensorFunctor {
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq->mutable_data<T>(context.GetPlace());
 
+#ifdef PADDLE_WITH_HIP
+    hipLaunchKernelGGL(
+        HIP_KERNEL_NAME(SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS>),
+        dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(),
+        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
+        scales, seq_width);
+#else
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
         seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
         scales, seq_width);
+#endif
   }
 };
diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index 742dc7f444..879e367281 100644
--- a/paddle/fluid/operators/math/softmax.cu
+++ b/paddle/fluid/operators/math/softmax.cu
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/math/softmax_impl.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else #include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { @@ -45,6 +49,16 @@ void SoftmaxCUDNNFunctor::operator()( if (cudnn_tensor_dims.size() <= 2) { cudnn_tensor_dims.resize(4, 1); } +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t cudnn_x_desc = + xDesc.descriptor(layout, cudnn_tensor_dims); + miopenTensorDescriptor_t cudnn_y_desc = + xDesc.descriptor(layout, cudnn_tensor_dims); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward( + context.cudnn_handle(), CudnnDataType::kOne(), cudnn_x_desc, + X->data(), CudnnDataType::kZero(), cudnn_y_desc, + Y->mutable_data(context.GetPlace()))); +#else cudnnTensorDescriptor_t cudnn_x_desc = xDesc.descriptor(layout, cudnn_tensor_dims); cudnnTensorDescriptor_t cudnn_y_desc = @@ -54,6 +68,7 @@ void SoftmaxCUDNNFunctor::operator()( CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType::kOne(), cudnn_x_desc, X->data(), CudnnDataType::kZero(), cudnn_y_desc, Y->mutable_data(context.GetPlace()))); +#endif } template @@ -74,6 +89,19 @@ void SoftmaxGradCUDNNFunctor::operator()( if (cudnn_tensor_dims.size() <= 2) { cudnn_tensor_dims.resize(4, 1); } +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t cudnn_y_desc = + yDesc.descriptor(layout, cudnn_tensor_dims); + miopenTensorDescriptor_t cudnn_xgrad_desc = + dxDesc.descriptor(layout, cudnn_tensor_dims); + miopenTensorDescriptor_t cudnn_ygrad_desc = + dyDesc.descriptor(layout, cudnn_tensor_dims); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward( + context.cudnn_handle(), CudnnDataType::kOne(), cudnn_y_desc, + Y->data(), cudnn_ygrad_desc, YGrad->data(), + CudnnDataType::kZero(), cudnn_xgrad_desc, + XGrad->mutable_data(context.GetPlace()))); +#else cudnnTensorDescriptor_t cudnn_y_desc = yDesc.descriptor(layout, cudnn_tensor_dims); cudnnTensorDescriptor_t cudnn_xgrad_desc = @@ -86,15 +114,20 @@ void SoftmaxGradCUDNNFunctor::operator()( Y->data(), cudnn_ygrad_desc, YGrad->data(), CudnnDataType::kZero(), cudnn_xgrad_desc, XGrad->mutable_data(context.GetPlace()))); +#endif } -template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; -template class SoftmaxCUDNNFunctor; +template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; -template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; +// MIOPEN do not support double +#ifndef PADDLE_WITH_HIP +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +#endif + template class SoftmaxFunctor; template class SoftmaxFunctor class SoftmaxCUDNNFunctor { public: diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index 4bb0e1d582..a738816c40 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -22,7 +22,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/pooling.h" -#ifdef __NVCC__ +#if defined(__HIPCC__) || defined(__NVCC__) #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #endif diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 8bb0779bc0..f5c58eb451 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -278,6 +278,9 @@ class OpTest(unittest.TestCase): def is_mkldnn_op_test(): return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True + def is_rocm_op_test(): + return core.is_compiled_with_rocm() + if not hasattr(cls, "op_type"): raise AssertionError( "This test do not have op_type in class attrs, " @@ -298,7 +301,8 @@ class OpTest(unittest.TestCase): and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \ and not hasattr(cls, 'exist_fp64_check_grad') \ and not is_xpu_op_test() \ - and not is_mkldnn_op_test(): + and not is_mkldnn_op_test() \ + and not is_rocm_op_test(): raise AssertionError( "This test of %s op needs check_grad with fp64 precision." % cls.op_type)