@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
   using paddle::platform::CPUPlace;
   using paddle::platform::CUDAPlace;
-  DeviceContextPool& pool = DeviceContextPool::Get();
-  auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
-  auto cpu_dev_ctx2 = pool.Borrow(CPUPlace());
-  EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1);
+  DeviceContextPool& pool = DeviceContextPool::Instance();
+  auto cpu_dev_ctx1 = pool.Get(CPUPlace());
+  auto cpu_dev_ctx2 = pool.Get(CPUPlace());
+  ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
   std::vector<Place> gpu_places;
   int count = paddle::platform::GetCUDADeviceCount();
   for (int i = 0; i < count; ++i) {
-    gpu_places.emplace_back(CUDAPlace(i));
-  }
-  auto dev_ctxs = pool.Borrow(gpu_places);
-  for (size_t i = 0; i < dev_ctxs.size(); ++i) {
-    auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
-    // check same as CUDAPlace(i)
-    CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
-    EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
+    auto dev_ctx = pool.Get(CUDAPlace(i));
+    ASSERT_NE(dev_ctx, nullptr);
   }
 }
 int main(int argc, char** argv) {
   int dev_count = paddle::platform::GetCUDADeviceCount();
   if (dev_count <= 1) {
     LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
                     "device count is "
                  << dev_count;
     return 0;
   }
   std::vector<paddle::platform::Place> places;
   places.emplace_back(paddle::platform::CPUPlace());
@@ -109,7 +94,7 @@ int main(int argc, char** argv) {
   }
   VLOG(0) << " DeviceCount " << count;
-  paddle::platform::DeviceContextPool::Create(places);
+  paddle::platform::DeviceContextPool::Init(places);
   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
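
In summary, this change renames the `DeviceContextPool` interface: the singleton accessor `Get()` becomes `Instance()`, the per-place lookup `Borrow(place)` becomes `Get(place)`, and pool construction `Create(places)` becomes `Init(places)`. Below is a minimal usage sketch of the renamed interface, assuming only the calls that appear in this diff; the header paths and the helper function name are assumptions, not part of the change itself.

```cpp
#include <vector>

#include "paddle/platform/device_context.h"  // assumed header path
#include "paddle/platform/place.h"           // assumed header path

// Hypothetical helper, only to illustrate the renamed calls.
void PoolUsageSketch() {
  using paddle::platform::DeviceContextPool;
  using paddle::platform::CPUPlace;
  using paddle::platform::CUDAPlace;
  using paddle::platform::Place;

  // Init(places) replaces Create(places): build the pool once at startup.
  std::vector<Place> places;
  places.emplace_back(CPUPlace());
  places.emplace_back(CUDAPlace(0));
  DeviceContextPool::Init(places);

  // Instance() replaces Get(): access the process-wide singleton.
  DeviceContextPool& pool = DeviceContextPool::Instance();

  // Get(place) replaces Borrow(place): repeated calls with the same place
  // return the same context, which is what the ASSERT_EQ above checks.
  auto cpu_ctx1 = pool.Get(CPUPlace());
  auto cpu_ctx2 = pool.Get(CPUPlace());
  // cpu_ctx1 == cpu_ctx2
}
```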