|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
|
|
|
|
|
# TODO: define the functions to manipulate devices
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid.dygraph.parallel import ParallelEnv
|
|
|
|
@ -137,7 +137,9 @@ def set_device(device):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The device should not be 'xpu', " \
|
|
|
|
|
"since PaddlePaddle is not compiled with XPU")
|
|
|
|
|
place = core.XPUPlace(ParallelEnv().dev_id)
|
|
|
|
|
selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
|
|
|
|
|
device_id = int(selected_xpus[0])
|
|
|
|
|
place = core.XPUPlace(device_id)
|
|
|
|
|
else:
|
|
|
|
|
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
|
|
|
|
|
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
|
|
|
|
|