revert-31068-fix_conv3d_windows
taixiurong 4 years ago committed by GitHub
parent 71acde9afc
commit 24873f4f77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,11 +29,11 @@ class XPURangeKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0];
framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*end_t, platform::CPUPlace(), &n);
T end = n.data<T>()[0];
framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*step_t, platform::CPUPlace(), &n);
T step = n.data<T>()[0];
int64_t size = 0;

@ -14,7 +14,7 @@
# TODO: define the functions to manipulate devices
import re
import os
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
@ -137,7 +137,9 @@ def set_device(device):
raise ValueError(
"The device should not be 'xpu', " \
"since PaddlePaddle is not compiled with XPU")
place = core.XPUPlace(ParallelEnv().dev_id)
selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
device_id = int(selected_xpus[0])
place = core.XPUPlace(device_id)
else:
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)

Loading…
Cancel
Save