|
|
|
@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
|
|
|
|
|
return it->second.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<const DeviceContext*>
|
|
|
|
|
DeviceContextPool::GetAllDeviceContexts() const {
|
|
|
|
|
std::vector<const DeviceContext*> all_device_ctx;
|
|
|
|
|
all_device_ctx.reserve(device_contexts_.size());
|
|
|
|
|
for (auto& dev_ctx : device_contexts_) {
|
|
|
|
|
all_device_ctx.emplace_back(dev_ctx.second.get());
|
|
|
|
|
}
|
|
|
|
|
return all_device_ctx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DeviceContextPool::DeviceContextPool(
|
|
|
|
|
const std::vector<platform::Place>& places) {
|
|
|
|
|
PADDLE_ENFORCE_GT(places.size(), 0);
|
|
|
|
|