[ROCM] update fluid platform for rocm39 (part4), test=develop (#30936)

Qi Li
parent a5c56d83a1
commit 334296306c

@ -16,13 +16,20 @@ endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
if (WITH_ROCM)
hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
if (WITH_GPU)
if (WITH_GPU OR WITH_ROCM)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator)
elseif(WITH_XPU)
set(AllocatorFacadeDeps xpu_info)
@ -40,6 +47,16 @@ if (WITH_GPU)
cuda_allocator
device_context
memcpy)
elseif (WITH_ROCM)
hip_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
best_fit_allocator_test.cu
DEPS best_fit_allocator
locked_allocator
cpu_allocator
cuda_allocator
device_context
memcpy)
else()
cc_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
@ -57,7 +74,7 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
if (WITH_TESTING)
if (WITH_GPU AND TARGET retry_allocator_test)
if ((WITH_GPU OR WITH_ROCM) AND TARGET retry_allocator_test)
target_link_libraries(retry_allocator_test cuda_allocator)
endif()

@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <thread> // NOLINT
#include <vector>
@ -40,8 +47,13 @@ __global__ void kernel(float *x, int n) {
void CheckKernelOutput(float *x, int n) {
auto host_x = std::unique_ptr<float[]>(new float[n]);
for (int i = 0; i < n; ++i) {
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipMemcpy(host_x.get(), x, n * sizeof(float),
hipMemcpyDeviceToHost));
#else
EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
cudaMemcpyDeviceToHost));
#endif
EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
}
@ -53,13 +65,22 @@ void MultiStreamCompute(float **data, float **second_data,
AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *data, N);
#else
kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
#endif
// allocate and compute on same stream again
allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *second_data,
N);
#else
kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
#endif
}
TEST(Malloc, CUDADeviceContextMultiStream) {
@ -75,8 +96,12 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
// default stream
// default stream
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
#else
kernel<<<1, 64>>>(main_stream_data, N);
#endif
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
@ -85,7 +110,11 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
}
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
#else
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
#endif
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
@ -106,8 +135,12 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
CudaDevCtxVec dev_ctx;
std::vector<std::thread> threads;
// default stream
// default stream
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
#else
kernel<<<1, 64>>>(main_stream_data, N);
#endif
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
@ -120,8 +153,11 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
for (int i = 0; i < NUM_STREAMS; ++i) {
threads[i].join();
}
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
#else
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
#endif
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
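Every launch in the test above now has a HIP twin: kernel<<<...>>> under NVCC, hipLaunchKernelGGL under hipcc. A minimal sketch, assuming the same PADDLE_WITH_HIP guard, of folding the two launch forms into one helper macro; LAUNCH_KERNEL is an illustrative name, not something this commit adds (the commit guards each call site instead).
// Hypothetical helper; grid/block/shared-mem/stream arguments mirror the call sites above.
#ifdef PADDLE_WITH_HIP
#define LAUNCH_KERNEL(kern, grid, block, shm, stream, ...) \
  hipLaunchKernelGGL((kern), dim3(grid), dim3(block), shm, stream, __VA_ARGS__)
#else
#define LAUNCH_KERNEL(kern, grid, block, shm, stream, ...) \
  (kern)<<<(grid), (block), (shm), (stream)>>>(__VA_ARGS__)
#endif
// Usage matching the test body: LAUNCH_KERNEL(kernel, 1, 64, 0, ctx.stream(), *data, N);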

@ -196,9 +196,22 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024; // 64K
#ifdef PADDLE_WITH_HIP
inline void SyncCUDAStream() {
#if !defined(_WIN32)
hipStreamSynchronize(0);
#else
hipError_t e_sync = hipSuccess;
while (e_sync = hipStreamQuery(0)) {
if (e_sync == hipErrorNotReady) continue;
break;
}
#endif
}
#else
inline void SyncCUDAStream() {
#if !defined(_WIN32)
cudaStreamSynchronize(0);
@ -210,6 +223,7 @@ inline void SyncCUDAStream() {
}
#endif
}
#endif
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
@ -228,10 +242,18 @@ void Copy<platform::CPUPlace, platform::CUDAPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
// FIXME(zjl): do we really need it?
if (num <= kMaxGpuAsyncCopyBytes) {
SyncCUDAStream();
@ -250,10 +272,18 @@ void Copy<platform::CUDAPlace, platform::CPUPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
// FIXME(zjl): do we really need it?
if (num <= kMaxGpuAsyncCopyBytes) {
SyncCUDAStream();
@ -273,10 +303,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
platform::SetDeviceId(src_place.device);
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
#endif
}
} else {
if (stream) {
@ -332,10 +370,18 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
}
}
@ -351,10 +397,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
}
}
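Each Copy specialization above swaps a cudaMemcpy* direction constant for its hipMemcpy* counterpart while the surrounding logic stays identical. A sketch, assuming the same guard, of how the direction kinds could be centralized once instead of per call site; the kGpuMemcpy* names are illustrative and not part of this commit.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
constexpr auto kGpuMemcpyHostToDevice = hipMemcpyHostToDevice;
constexpr auto kGpuMemcpyDeviceToHost = hipMemcpyDeviceToHost;
constexpr auto kGpuMemcpyDeviceToDevice = hipMemcpyDeviceToDevice;
#else
#include <cuda_runtime.h>
constexpr auto kGpuMemcpyHostToDevice = cudaMemcpyHostToDevice;
constexpr auto kGpuMemcpyDeviceToHost = cudaMemcpyDeviceToHost;
constexpr auto kGpuMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice;
#endif
// With these, GpuMemcpyAsync(dst, src, num, kGpuMemcpyDeviceToHost, stream)
// would compile unchanged for both backends.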

@ -41,27 +41,44 @@ float test_pinned_memory() {
const int iteration = 10;
// create event start and end
cudaEvent_t start_e, stop_e, copying_e;
gpuEvent_t start_e, stop_e, copying_e;
float elapsedTime = 0;
#ifdef PADDLE_WITH_HIP
hipEventCreate(&start_e);
hipEventCreate(&stop_e);
hipEventCreate(&copying_e);
#else
cudaEventCreate(&start_e);
cudaEventCreate(&stop_e);
cudaEventCreate(&copying_e);
#endif
// create computation stream, data copying stream
cudaStream_t computation_stream, copying_stream;
gpuStream_t computation_stream, copying_stream;
#ifdef PADDLE_WITH_HIP
hipStreamCreate(&computation_stream);
hipStreamCreate(&copying_stream);
#else
cudaStreamCreate(&computation_stream);
cudaStreamCreate(&copying_stream);
#endif
// create record event, pinned memory, gpu memory
std::vector<cudaEvent_t> record_event(iteration);
std::vector<gpuEvent_t> record_event(iteration);
std::vector<float*> input_pinned_mem(iteration);
std::vector<float*> gpu_mem(iteration);
std::vector<float*> output_pinned_mem(iteration);
// initial data
for (int j = 0; j < iteration; ++j) {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&record_event[j], hipEventDisableTiming);
hipEventCreate(&(record_event[j]));
#else
cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming);
cudaEventCreate(&(record_event[j]));
#endif
input_pinned_mem[j] = static_cast<float*>(
paddle::memory::Alloc(cpu_place, data_size * sizeof(float)));
output_pinned_mem[j] = static_cast<float*>(
@ -74,7 +91,11 @@ float test_pinned_memory() {
}
}
#ifdef PADDLE_WITH_HIP
hipEventRecord(start_e, computation_stream);
#else
cudaEventRecord(start_e, computation_stream);
#endif
// computation
for (int m = 0; m < 30; ++m) {
@ -88,13 +109,21 @@ float test_pinned_memory() {
// call kernel on computation stream.
Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size);
#ifdef PADDLE_WITH_HIP
// record event_computation on computation stream
hipEventRecord(record_event[i], computation_stream);
// wait event_computation on copy stream.
// note: this operation is async.
hipStreamWaitEvent(copying_stream, record_event[i], 0);
#else
// record event_computation on computation stream
cudaEventRecord(record_event[i], computation_stream);
// wait event_computation on copy stream.
// note: this operation is async.
cudaStreamWaitEvent(copying_stream, record_event[i], 0);
#endif
// copy data GPU->CPU, on copy stream.
// note: this operation is async for pinned memory.
paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,
@ -103,6 +132,16 @@ float test_pinned_memory() {
}
}
#ifdef PADDLE_WITH_HIP
hipEventRecord(copying_e, copying_stream);
hipStreamWaitEvent(computation_stream, copying_e, 0);
hipEventRecord(stop_e, computation_stream);
hipEventSynchronize(start_e);
hipEventSynchronize(stop_e);
hipEventElapsedTime(&elapsedTime, start_e, stop_e);
#else
cudaEventRecord(copying_e, copying_stream);
cudaStreamWaitEvent(computation_stream, copying_e, 0);
@ -111,6 +150,7 @@ float test_pinned_memory() {
cudaEventSynchronize(start_e);
cudaEventSynchronize(stop_e);
cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
#endif
// std::cout << cpu_place << " "
// << "time consume:" << elapsedTime / 30 << std::endl;
@ -123,12 +163,22 @@ float test_pinned_memory() {
}
}
// destroy resource
// destroy resource
#ifdef PADDLE_WITH_HIP
hipEventDestroy(copying_e);
hipEventDestroy(start_e);
hipEventDestroy(stop_e);
#else
cudaEventDestroy(copying_e);
cudaEventDestroy(start_e);
cudaEventDestroy(stop_e);
#endif
for (int j = 0; j < 10; ++j) {
#ifdef PADDLE_WITH_HIP
hipEventDestroy((record_event[j]));
#else
cudaEventDestroy((record_event[j]));
#endif
paddle::memory::Free(cpu_place, input_pinned_mem[j]);
paddle::memory::Free(cpu_place, output_pinned_mem[j]);
paddle::memory::Free(cuda_place, gpu_mem[j]);
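The test above now declares gpuEvent_t and gpuStream_t instead of the CUDA types; the aliases are defined elsewhere in the platform code, roughly as in this assumed sketch.
// Assumed shape of the unified aliases; the real definitions live in
// paddle/fluid/platform, not in this test file.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
using gpuEvent_t = hipEvent_t;
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
using gpuEvent_t = cudaEvent_t;
#endif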

@ -21,7 +21,7 @@ size_t Alignment(size_t size, const platform::Place &place) {
if (platform::is_cpu_place(place)) {
alignment = CpuMinChunkSize();
} else {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
alignment = GpuMinChunkSize();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(

@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/gpu_info.h"
#endif
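The Alignment() hunk above only changes which minimum chunk size is picked on the GPU path; the rounding itself sits outside the diff context, but it is the usual remainder trick, sketched here as an assumption.
#include <cstddef>
// Round size up to a multiple of alignment (alignment must be non-zero).
inline size_t AlignUp(size_t size, size_t alignment) {
  const size_t remaining = size % alignment;
  return remaining == 0 ? size : size + (alignment - remaining);
}
// e.g. AlignUp(100, 256) == 256, AlignUp(512, 256) == 512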

@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
@ -337,7 +338,7 @@ void* GetNVRTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false);
#elif defined(PADDLE_WITH_HIP)
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprtc.so", false);
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so", false);
#endif
@ -347,7 +348,7 @@ void* GetCUDADsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false);
#elif defined(PADDLE_WITH_HIP)
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhip_hcc.so", false);
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false);
#endif
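Both handles now resolve to libamdhip64.so, the single ROCm runtime library used in place of both libhiprtc.so and libhip_hcc.so. A rough sketch, assuming a plain POSIX dlopen, of what a search-path load looks like; Paddle's GetDsoHandleFromSearchPath additionally handles fallbacks, warnings, and platform differences.
#include <dlfcn.h>
#include <string>
// Illustrative helper, not Paddle's actual implementation; the flag choice is an assumption.
void* OpenFromDir(const std::string& dir, const std::string& name) {
  const std::string path = dir.empty() ? name : dir + "/" + name;
  return dlopen(path.c_str(), RTLD_LAZY | RTLD_GLOBAL);  // nullptr on failure
}
// Usage sketch: void* handle = OpenFromDir(FLAGS_rocm_dir, "libamdhip64.so");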

@ -45,6 +45,7 @@ extern bool HasNVRTC();
* include all needed hiprtc functions
**/
#define HIPRTC_ROUTINE_EACH(__macro) \
__macro(hiprtcVersion); \
__macro(hiprtcGetErrorString); \
__macro(hiprtcCompileProgram); \
__macro(hiprtcCreateProgram); \
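HIPRTC_ROUTINE_EACH is an X-macro: the routine names are listed once and a wrapper-generating macro is applied to each entry, which is why exposing hiprtcVersion is a one-line addition. A self-contained sketch of the pattern with placeholder names; Paddle's real wrapper macro dlsym-loads each symbol on first use rather than merely recording the name.
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string> g_routines;

#define ROUTINE_EACH(__macro) \
  __macro(foo);               \
  __macro(bar)

// The "wrapper" here just registers the name to show the expansion.
#define REGISTER_ROUTINE(__name) \
  static bool registered_##__name = (g_routines.push_back(#__name), true)

ROUTINE_EACH(REGISTER_ROUTINE);

int main() {
  for (const auto& name : g_routines) std::cout << name << "\n";  // prints foo, bar
}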

@ -16,10 +16,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <miopen/miopen.h>
#include <miopen/version.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
#define MIOPEN_VERSION \
(MIOPEN_VERSION_MAJOR * 1000 + MIOPEN_VERSION_MINOR * 100 + \
MIOPEN_VERSION_PATCH) // NOLINT
namespace paddle {
namespace platform {
namespace dynload {
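The new MIOPEN_VERSION macro packs major/minor/patch into one integer so version guards become a single comparison. A worked example of the encoding; the version numbers are illustrative, not a claim about which MIOpen release ships with ROCm 3.9.
// MIOPEN_VERSION = MAJOR * 1000 + MINOR * 100 + PATCH
constexpr int PackMiopenVersion(int major, int minor, int patch) {
  return major * 1000 + minor * 100 + patch;
}
static_assert(PackMiopenVersion(2, 11, 0) == 3100, "2.11.0 packs to 3100");
// so a guard such as MIOPEN_VERSION >= 3100 reads "MIOpen 2.11 or newer".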

@ -46,6 +46,7 @@ extern bool HasCUDADriver();
* include all needed cuda driver functions
**/
#define ROCM_ROUTINE_EACH(__macro) \
__macro(hipDriverGetVersion); \
__macro(hipGetErrorString); \
__macro(hipModuleLoadData); \
__macro(hipModuleGetFunction); \

@ -18,6 +18,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include "paddle/fluid/platform/place.h"
namespace paddle {
@ -48,9 +51,9 @@ class Event {
void set_name(std::string name) { name_ = name; }
void set_role(EventRole role) { role_ = role; }
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef PADDLE_WITH_CUPTI
cudaEvent_t event() const { return event_; }
gpuEvent_t event() const { return event_; }
int device() const { return device_; }
#endif
#endif
@ -66,7 +69,7 @@ class Event {
EventRole role_{};
int64_t cpu_ns_;
bool visited_status_{false};
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUPTI
int64_t gpu_ns_ = 0;
@ -77,7 +80,7 @@ class Event {
private:
#else
cudaEvent_t event_ = nullptr;
gpuEvent_t event_ = nullptr;
int device_ = -1;
#endif
#endif
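With event_ stored as the unified gpuEvent_t, the same Event class builds for both backends. A hedged sketch of how such an event pair yields elapsed milliseconds, using the same guard style as the rest of the commit; this is not the profiler's actual timing path.
// gpuEvent_t is the alias sketched earlier (hipEvent_t or cudaEvent_t).
float ElapsedMs(gpuEvent_t start, gpuEvent_t stop) {
  float ms = 0.f;
#ifdef PADDLE_WITH_HIP
  hipEventSynchronize(stop);
  hipEventElapsedTime(&ms, start, stop);
#else
  cudaEventSynchronize(stop);
  cudaEventElapsedTime(&ms, start, stop);
#endif
  return ms;
}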

@ -13,7 +13,7 @@
// limitations under the License.
#include "gflags/gflags.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#endif
@ -45,7 +45,7 @@ DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* CUDA related related FLAG
@ -84,7 +84,7 @@ DEFINE_string(selected_gpus, "",
"share-memory only.");
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* CUDNN related FLAG
@ -167,7 +167,7 @@ DEFINE_bool(cudnn_batchnorm_spatial_persistent, false,
"batch_norm, default is False.");
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* NCCL related FLAG
@ -377,7 +377,7 @@ DEFINE_double(
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
"reserve the rest for page tables, etc");
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* Memory related FLAG

@ -40,7 +40,7 @@ struct ForRange<CPUDeviceContext> {
size_t limit_;
};
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
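ForRange dispatches work to small element-wise kernels, which is why the guard widens from __NVCC__ to __NVCC__ || __HIPCC__: hipcc defines __HIPCC__ when it compiles the same source. For context, a sketch of the multi-block companion to the single-block kernel above; simplified, with the bounds check that the grid-is-one case does not need.
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < limit) {
    func(idx);  // threads past limit do nothing
  }
}
#endif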

@ -16,8 +16,10 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cupti.h"
#endif
#include "paddle/fluid/platform/device_context.h"
@ -92,6 +94,7 @@ bool InitGflags(std::vector<std::string> args) {
return successed;
}
#ifdef PADDLE_WITH_CUDA
void InitCupti() {
#ifdef PADDLE_WITH_CUPTI
if (FLAGS_multiple_of_cupti_buffer_size == 1) return;
@ -117,14 +120,17 @@ void InitCupti() {
#undef MULTIPLY_ATTR_VALUE
#endif
}
#endif
void InitDevices() {
// CUPTI attribute should be set before any CUDA context is created (see CUPTI
// documentation about CUpti_ActivityAttribute).
// CUPTI attribute should be set before any CUDA context is created (see CUPTI
// documentation about CUpti_ActivityAttribute).
#ifdef PADDLE_WITH_CUDA
InitCupti();
#endif
/*Init all available devices by default */
std::vector<int> devices;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
try {
// use user specified GPUs in single-node multi-process mode.
devices = platform::GetSelectedDevices();
@ -154,7 +160,7 @@ void InitDevices(const std::vector<int> devices) {
continue;
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPlace(devices[i]));
#endif
#ifdef PADDLE_WITH_XPU
@ -162,7 +168,7 @@ void InitDevices(const std::vector<int> devices) {
#endif
}
places.emplace_back(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPinnedPlace());
#endif
platform::DeviceContextPool::Init(places);

@ -19,7 +19,8 @@ TEST(InitDevices, CPU) {
using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool;
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU) && \
!defined(PADDLE_WITH_HIP)
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U);
@ -30,7 +31,7 @@ TEST(InitDevices, CUDA) {
using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();

File diff suppressed because it is too large.

@ -0,0 +1,93 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "paddle/fluid/platform/miopen_helper.h"
#include <gtest/gtest.h>
TEST(MIOpenHelper, ScopedTensorDescriptor) {
using paddle::platform::ScopedTensorDescriptor;
using paddle::platform::DataLayout;
ScopedTensorDescriptor tensor_desc;
std::vector<int> shape = {2, 4, 6, 6};
auto desc = tensor_desc.descriptor<float>(DataLayout::kNCHW, shape);
miopenDataType_t type;
int nd;
std::vector<int> dims(4);
std::vector<int> strides(4);
paddle::platform::dynload::miopenGetTensorDescriptor(desc, &type, dims.data(),
strides.data());
paddle::platform::dynload::miopenGetTensorDescriptorSize(desc, &nd);
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < dims.size(); ++i) {
EXPECT_EQ(dims[i], shape[i]);
}
EXPECT_EQ(strides[3], 1);
EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor tensor5d_desc;
std::vector<int> shape_5d = {2, 4, 6, 6, 6};
auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);
std::vector<int> dims_5d(5);
std::vector<int> strides_5d(5);
paddle::platform::dynload::miopenGetTensorDescriptor(
desc_5d, &type, dims_5d.data(), strides_5d.data());
paddle::platform::dynload::miopenGetTensorDescriptorSize(desc_5d, &nd);
EXPECT_EQ(nd, 5);
for (size_t i = 0; i < dims_5d.size(); ++i) {
EXPECT_EQ(dims_5d[i], shape_5d[i]);
}
EXPECT_EQ(strides_5d[4], 1);
EXPECT_EQ(strides_5d[3], 6);
EXPECT_EQ(strides_5d[2], 36);
EXPECT_EQ(strides_5d[1], 216);
EXPECT_EQ(strides_5d[0], 864);
}
TEST(MIOpenHelper, ScopedConvolutionDescriptor) {
using paddle::platform::ScopedConvolutionDescriptor;
ScopedConvolutionDescriptor conv_desc;
std::vector<int> src_pads = {2, 2, 2};
std::vector<int> src_strides = {1, 1, 1};
std::vector<int> src_dilations = {1, 1, 1};
auto desc = conv_desc.descriptor<float>(src_pads, src_strides, src_dilations);
miopenConvolutionMode_t mode;
int nd;
std::vector<int> pads(3);
std::vector<int> strides(3);
std::vector<int> dilations(3);
paddle::platform::dynload::miopenGetConvolutionNdDescriptor(
desc, 3, &nd, pads.data(), strides.data(), dilations.data(), &mode);
EXPECT_EQ(nd, 3);
for (size_t i = 0; i < src_pads.size(); ++i) {
EXPECT_EQ(pads[i], src_pads[i]);
EXPECT_EQ(strides[i], src_strides[i]);
EXPECT_EQ(dilations[i], src_dilations[i]);
}
EXPECT_EQ(mode, miopenConvolution);
}
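The stride expectations in both tests follow from a packed NCHW/NCDHW layout: the last stride is 1 and each earlier stride is the product of the trailing dimensions, so {2, 4, 6, 6} yields {144, 36, 6, 1} and {2, 4, 6, 6, 6} yields {864, 216, 36, 6, 1}. A small standalone check that reproduces those numbers:
#include <cassert>
#include <vector>

// Strides of a densely packed tensor: stride[i] = product of dims[i+1..n-1].
std::vector<int> PackedStrides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  int running = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= dims[i];
  }
  return strides;
}

int main() {
  assert(PackedStrides({2, 4, 6, 6}) == (std::vector<int>{144, 36, 6, 1}));
  assert(PackedStrides({2, 4, 6, 6, 6}) ==
         (std::vector<int>{864, 216, 36, 6, 1}));
  return 0;
}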