check whether scalar condition var is on CPU before using

helinwang-patch-1
JiayiFeng 7 years ago
parent 01c5ca7364
commit 0ac43217ce

@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase {
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
return ips[0]->data<bool>()[0];
bool res;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
framework::LoDTensor cpu_tensor;
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else {
res = ips[0]->data<bool>()[0];
}
return res;
}
};

Loading…
Cancel
Save