[ROCM] update fluid operators for rocm (part9), test=develop (#31338)

test_model_benchmark
Qi Li 4 years ago committed by GitHub
parent 6626c6a6ad
commit e312a1ff6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/p_norm_op.h"
namespace paddle {

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/prroi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
@ -29,22 +28,6 @@ static inline int NumBlocks(const int N) {
kNumMaximumNumBlocks);
}
template <typename T>
DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff,
const int h, const int w,
const int height, const int width,
const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) {
paddle::platform::CudaAtomicAdd(diff + h * width + w, top_diff * coeff);
}
}
template <typename T>
DEVICE void GPUAccumulateRois(T* offset, T data) {
paddle::platform::CudaAtomicAdd(offset, data);
}
template <typename T>
__global__ void GPUPRROIPoolForward(
const int nthreads, const T* input_data, const T* input_rois,
@ -170,25 +153,23 @@ __global__ void GPUPRROIPoolBackward(
for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
PrRoIPoolingMatDistributeDiff(
PrRoIPoolingMatDistributeDiff<T>(
offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1,
w_iter + 1, max(win_start_h, static_cast<T>(h_iter)),
max(win_start_w, static_cast<T>(w_iter)),
min(win_end_h, static_cast<T>(h_iter) + static_cast<T>(1.0)),
min(win_end_w, static_cast<T>(w_iter) + static_cast<T>(1.0)),
height, width, PrRoIPoolingDistributeDiffCUDA<T>);
height, width);
}
}
const T* offset_out_data = out_data + i;
const T* offset_in_data = in_data + input_offset;
PrRoIPoolingCoorBackward(
PrRoIPoolingCoorBackward<T>(
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w,
win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale,
offset_in_data, offset_out_data, offset_input_roi_grad_data,
offset_output_grad_data, GPUAccumulateRois<T>,
[](const T x, const T y) { return max(x, y); },
[](const T x, const T y) { return min(x, y); });
offset_output_grad_data);
}
}

@ -16,6 +16,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/cuda_primitives.h"
#endif
namespace paddle {
namespace operators {
@ -73,6 +76,17 @@ inline HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data,
return sum_out;
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
DEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, const int h,
const int w, const int height,
const int width, const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) {
paddle::platform::CudaAtomicAdd(diff + h * width + w, top_diff * coeff);
}
}
#else
template <typename T>
inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff,
const int h, const int w,
@ -84,12 +98,15 @@ inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff,
*(diff + h * width + w) += top_diff * coeff;
}
}
#endif
template <typename T, typename Functor>
HOSTDEVICE void PrRoIPoolingMatDistributeDiff(
T* diff, const T top_diff, const int s_h, const int s_w, const int e_h,
const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
const int w0, Functor functor) {
template <typename T>
HOSTDEVICE void PrRoIPoolingMatDistributeDiff(T* diff, const T top_diff,
const int s_h, const int s_w,
const int e_h, const int e_w,
const T y0, const T x0,
const T y1, const T x1,
const int h0, const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
alpha = x0 - static_cast<T>(s_w);
@ -99,14 +116,14 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff(
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
functor(diff, top_diff, s_h, s_w, h0, w0, tmp);
PrRoIPoolingDistributeDiff<T>(diff, top_diff, s_h, s_w, h0, w0, tmp);
alpha = static_cast<T>(e_w) - x1;
lim_alpha = static_cast<T>(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
functor(diff, top_diff, s_h, e_w, h0, w0, tmp);
PrRoIPoolingDistributeDiff<T>(diff, top_diff, s_h, e_w, h0, w0, tmp);
alpha = x0 - static_cast<T>(s_w);
beta = static_cast<T>(e_h) - y1;
@ -115,20 +132,47 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff(
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
functor(diff, top_diff, e_h, s_w, h0, w0, tmp);
PrRoIPoolingDistributeDiff<T>(diff, top_diff, e_h, s_w, h0, w0, tmp);
alpha = static_cast<T>(e_w) - x1;
lim_alpha = static_cast<T>(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
0.5f * alpha * alpha) *
(lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
functor(diff, top_diff, e_h, e_w, h0, w0, tmp);
PrRoIPoolingDistributeDiff<T>(diff, top_diff, e_h, e_w, h0, w0, tmp);
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
DEVICE void AccumulateRois(T* offset, T data) {
paddle::platform::CudaAtomicAdd(offset, data);
}
#else
template <typename T>
inline HOSTDEVICE void CPUAccumulateRois(T* offset, T data) {
inline HOSTDEVICE void AccumulateRois(T* offset, T data) {
*offset += data;
}
#endif
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
DEVICE T MaxFunctor(const T x, const T y) {
return max(x, y);
}
template <typename T>
DEVICE T MinFunctor(const T x, const T y) {
return min(x, y);
}
#else
template <typename T>
inline HOSTDEVICE T MaxFunctor(const T x, const T y) {
return std::max(x, y);
}
template <typename T>
inline HOSTDEVICE T MinFunctor(const T x, const T y) {
return std::max(x, y);
}
#endif
template <typename T>
inline HOSTDEVICE static T PrRoIPoolingGetCoeff(T dh, T dw) {
@ -172,15 +216,13 @@ inline HOSTDEVICE T PrRoIPoolingSingleCoorIntegral(T s, T t, T c1, T c2) {
(t - 0.5f * t * t - s + 0.5f * s * s) * c1;
}
template <typename T, typename Functor, typename MaxFunctor,
typename MinFunctor>
template <typename T>
inline HOSTDEVICE void PrRoIPoolingCoorBackward(
int s_w, int e_w, int s_h, int e_h, int width, int height, T win_start_w,
T win_start_h, T win_end_w, T win_end_h, int pw, int ph,
const int pooled_width, const int pooled_height, T win_size,
const float spatial_scale, const T* this_bottom_data,
const T* this_top_data, T* this_data_grad, const T* this_out_grad,
Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) {
const T* this_top_data, T* this_data_grad, const T* this_out_grad) {
T g_x1_y = 0.f;
T g_x2_y = 0.f;
T g_x_y1 = 0.f;
@ -188,16 +230,16 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward(
for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
g_x1_y += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_h, static_cast<T>(h_iter)) - h_iter,
minFunctor(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
MaxFunctor<T>(win_start_h, static_cast<T>(h_iter)) - h_iter,
MinFunctor<T>(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w,
height, width));
g_x2_y += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_h, static_cast<T>(h_iter)) - h_iter,
minFunctor(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
MaxFunctor<T>(win_start_h, static_cast<T>(h_iter)) - h_iter,
MinFunctor<T>(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w,
@ -206,16 +248,16 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward(
for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
g_x_y1 += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_w, static_cast<T>(w_iter)) - w_iter,
minFunctor(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
MaxFunctor<T>(win_start_w, static_cast<T>(w_iter)) - w_iter,
MinFunctor<T>(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1,
height, width));
g_x_y2 += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_w, static_cast<T>(w_iter)) - w_iter,
minFunctor(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
MaxFunctor<T>(win_start_w, static_cast<T>(w_iter)) - w_iter,
MinFunctor<T>(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1,
@ -232,19 +274,21 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward(
partial_y1 = partial_y1 / win_size * spatial_scale;
partial_y2 = partial_y2 / win_size * spatial_scale;
functor(this_data_grad + 0,
AccumulateRois<T>(
this_data_grad + 0,
(partial_x1 * (1.0 - static_cast<T>(pw) / pooled_width) +
partial_x2 * (1.0 - static_cast<T>(pw + 1) / pooled_width)) *
(*this_out_grad));
functor(this_data_grad + 1,
AccumulateRois<T>(
this_data_grad + 1,
(partial_y1 * (1.0 - static_cast<T>(ph) / pooled_height) +
partial_y2 * (1.0 - static_cast<T>(ph + 1) / pooled_height)) *
(*this_out_grad));
functor(this_data_grad + 2,
AccumulateRois<T>(this_data_grad + 2,
(partial_x2 * static_cast<T>(pw + 1) / pooled_width +
partial_x1 * static_cast<T>(pw) / pooled_width) *
(*this_out_grad));
functor(this_data_grad + 3,
AccumulateRois<T>(this_data_grad + 3,
(partial_y2 * static_cast<T>(ph + 1) / pooled_height +
partial_y1 * static_cast<T>(ph) / pooled_height) *
(*this_out_grad));
@ -516,7 +560,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
PrRoIPoolingMatDistributeDiff(
PrRoIPoolingMatDistributeDiff<T>(
offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1,
w_iter + 1, std::max(win_start_h, static_cast<T>(h_iter)),
std::max(win_start_w, static_cast<T>(w_iter)),
@ -524,19 +568,16 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
static_cast<T>(h_iter) + static_cast<T>(1.0)),
std::min(win_end_w,
static_cast<T>(w_iter) + static_cast<T>(1.0)),
height, width, PrRoIPoolingDistributeDiff<T>);
height, width);
}
}
const T* offset_in_data = in_data + input_offset;
PrRoIPoolingCoorBackward(
PrRoIPoolingCoorBackward<T>(
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h,
win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size,
spatial_scale, offset_in_data, offset_out_data,
offset_input_roi_grad_data, offset_output_grad_data,
CPUAccumulateRois<T>,
[](const T x, const T y) { return std::max(x, y); },
[](const T x, const T y) { return std::min(x, y); });
offset_input_roi_grad_data, offset_output_grad_data);
}
}
}

@ -47,7 +47,8 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
hidden_size, 0);
#endif
#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
(defined PADDLE_WITH_PSLIB)
auto hidden_size = ctx.Attr<int>("size");
auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
@ -90,7 +91,8 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, hidden_size, 0, batch_size);
#endif
#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
(defined PADDLE_WITH_PSLIB)
auto hidden_size = ctx.Attr<int>("size");
auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
gpu_ps_ptr->PushSparseGrad(ctx.GetPlace(), 0, all_keys, all_grad_values,

@ -18,7 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include <thrust/random.h>
#endif
@ -36,7 +36,7 @@ struct Random<platform::CPUDeviceContext> {
using UniformIntDist = std::uniform_int_distribution<T>;
};
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <>
struct Random<platform::CUDADeviceContext> {
using Engine = thrust::minstd_rand;

@ -50,7 +50,7 @@ __global__ void expand_input_by_rank_kernel(
}
template <typename T>
void expand_rank_attention_input(cudaStream_t stream, const T* input,
void expand_rank_attention_input(gpuStream_t stream, const T* input,
int input_row, int input_col, T* output,
int output_row, int output_col,
const int* rank_offset, int rank_offset_row,
@ -93,7 +93,7 @@ __global__ void expand_rank_attention_param_kernel(
}
template <typename T>
void expand_rank_attention_param(cudaStream_t stream, const T* input,
void expand_rank_attention_param(gpuStream_t stream, const T* input,
int input_row, int input_col,
const int* rank_offset, int rank_offset_row,
int rank_offset_col, const T* param,
@ -133,7 +133,7 @@ __global__ void merge_param_gradient_kernel(
}
template <typename T>
void merge_rank_attention_param_grad(cudaStream_t stream, T* expanded_grad,
void merge_rank_attention_param_grad(gpuStream_t stream, T* expanded_grad,
int expanded_grad_row,
int expanded_grad_col, T* param_grad,
int param_grad_row, int param_grad_col,

@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cublas.h>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"

@ -654,7 +654,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(
ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
ops::ReshapeDoubleGradKernel);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
uint8_t, ops::ReshapeKernel, int64_t,

File diff suppressed because it is too large Load Diff

@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include "paddle/fluid/operators/seed_op.h"
namespace paddle {

@ -63,7 +63,7 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!cpu_place) {
Tensor length;
length.mutable_data<IndexT>(framework::make_ddim({1}),
@ -71,9 +71,15 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
IndexT* length_data = length.data<IndexT>();
const IndexT* segment_ids = segment->data<IndexT>();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
hipMemcpyDeviceToHost));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
cudaMemcpyDeviceToHost));
#endif
IndexT length_host = length_data[0];
length_host++;

@ -37,7 +37,7 @@ inline int GetBranchNumber(const framework::LoDTensor &mask) {
}
// when platform::is_gpu_place(mask.place()) is ture
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(

@ -33,7 +33,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
using Vector = framework::Vector<T>;
#else

@ -11,7 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"

@ -16,7 +16,11 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
@ -388,18 +392,30 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward(
handle, platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data));
#endif
}
}
};
@ -496,19 +512,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward(
handle, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(),
desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data));
#endif
}
}
};
@ -518,6 +547,15 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<plat::float16>);
#else
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<double>,
@ -526,3 +564,4 @@ REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<double>,
ops::SoftmaxGradCUDNNKernel<plat::float16>);
#endif

@ -22,6 +22,10 @@ limitations under the License. */
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
@ -66,7 +70,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
@ -190,7 +194,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}

@ -82,7 +82,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(),

@ -98,7 +98,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
src + i * src_after, sizeof(T) * size);
} else {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto& gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);

@ -72,7 +72,7 @@ TEST(StridedMemcpy, CPUConcat) {
}
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(StridedMemcpy, GPUCrop) {
// clang-format off
int src[] = {

Loading…
Cancel
Save