|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
#include <string.h> // for strdup
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/init.h"
|
|
|
|
@ -46,17 +47,23 @@ void InitDevices() {
|
|
|
|
|
|
|
|
|
|
std::vector<platform::Place> places;
|
|
|
|
|
places.emplace_back(platform::CPUPlace());
|
|
|
|
|
int count = 0;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
int count = platform::GetCUDADeviceCount();
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
places.emplace_back(platform::CUDAPlace(i));
|
|
|
|
|
try {
|
|
|
|
|
count = platform::GetCUDADeviceCount();
|
|
|
|
|
} catch (const std::exception &exp) {
|
|
|
|
|
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
|
|
|
|
|
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
places.emplace_back(platform::CUDAPlace(i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool::Init(places);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|