|
|
@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Device, CudnnDeviceContext) {
|
|
|
|
TEST(Device, CUDNNDeviceContext) {
|
|
|
|
using paddle::platform::CudnnDeviceContext;
|
|
|
|
using paddle::platform::CUDNNDeviceContext;
|
|
|
|
using paddle::platform::CudnnPlace;
|
|
|
|
using paddle::platform::CUDNNPlace;
|
|
|
|
if (paddle::platform::dynload::HasCUDNN()) {
|
|
|
|
if (paddle::platform::dynload::HasCUDNN()) {
|
|
|
|
int count = paddle::platform::GetCUDADeviceCount();
|
|
|
|
int count = paddle::platform::GetCUDADeviceCount();
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
CudnnDeviceContext* device_context =
|
|
|
|
CUDNNDeviceContext* device_context =
|
|
|
|
new CudnnDeviceContext(CudnnPlace(i));
|
|
|
|
new CUDNNDeviceContext(CUDNNPlace(i));
|
|
|
|
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
|
|
|
|
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
|
|
|
|
ASSERT_NE(nullptr, cudnn_handle);
|
|
|
|
ASSERT_NE(nullptr, cudnn_handle);
|
|
|
|
ASSERT_NE(nullptr, device_context->stream());
|
|
|
|
ASSERT_NE(nullptr, device_context->stream());
|
|
|
|