|
|
|
@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
|
|
|
|
|
using paddle::platform::CPUPlace;
|
|
|
|
|
using paddle::platform::CUDAPlace;
|
|
|
|
|
|
|
|
|
|
DeviceContextPool& pool = DeviceContextPool::Get();
|
|
|
|
|
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
|
|
|
|
|
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace());
|
|
|
|
|
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1);
|
|
|
|
|
DeviceContextPool& pool = DeviceContextPool::Instance();
|
|
|
|
|
auto cpu_dev_ctx1 = pool.Get(CPUPlace());
|
|
|
|
|
auto cpu_dev_ctx2 = pool.Get(CPUPlace());
|
|
|
|
|
ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
|
|
|
|
|
|
|
|
|
|
std::vector<Place> gpu_places;
|
|
|
|
|
int count = paddle::platform::GetCUDADeviceCount();
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
gpu_places.emplace_back(CUDAPlace(i));
|
|
|
|
|
}
|
|
|
|
|
auto dev_ctxs = pool.Borrow(gpu_places);
|
|
|
|
|
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
|
|
|
|
|
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
|
|
|
|
|
|
|
|
|
|
// check same as CUDAPlace(i)
|
|
|
|
|
CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
|
|
|
|
|
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
|
|
|
|
|
auto dev_ctx = pool.Get(CUDAPlace(i));
|
|
|
|
|
ASSERT_NE(dev_ctx, nullptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
|
|
|
int dev_count = paddle::platform::GetCUDADeviceCount();
|
|
|
|
|
if (dev_count <= 1) {
|
|
|
|
|
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
|
|
|
|
|
"device count is "
|
|
|
|
|
<< dev_count;
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<paddle::platform::Place> places;
|
|
|
|
|
|
|
|
|
|
places.emplace_back(paddle::platform::CPUPlace());
|
|
|
|
@ -109,7 +94,7 @@ int main(int argc, char** argv) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(0) << " DeviceCount " << count;
|
|
|
|
|
paddle::platform::DeviceContextPool::Create(places);
|
|
|
|
|
paddle::platform::DeviceContextPool::Init(places);
|
|
|
|
|
|
|
|
|
|
testing::InitGoogleTest(&argc, argv);
|
|
|
|
|
return RUN_ALL_TESTS();
|
|
|
|
|