|
|
|
@ -155,8 +155,10 @@ class Vector {
|
|
|
|
|
|
|
|
|
|
// get cuda ptr. immutable
|
|
|
|
|
const T *CUDAData(platform::Place place) const {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(place),
|
|
|
|
|
"CUDA Data must on CUDA place");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_gpu_place(place), true,
|
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
|
"Place mismatch, CUDA Data must be on CUDA place."));
|
|
|
|
|
ImmutableCUDA(place);
|
|
|
|
|
return reinterpret_cast<T *>(gpu_->ptr());
|
|
|
|
|
}
|
|
|
|
@ -234,7 +236,8 @@ class Vector {
|
|
|
|
|
UnsetFlag(kDirty);
|
|
|
|
|
SetFlag(kDataInCUDA);
|
|
|
|
|
} else if (IsInCUDA() && !(place == gpu_->place())) {
|
|
|
|
|
PADDLE_THROW("This situation should not happen");
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
platform::errors::Unavailable("Unexpected data place mismatch."));
|
|
|
|
|
// Still dirty
|
|
|
|
|
} else {
|
|
|
|
|
// Dirty && DataInCUDA && Device is same
|
|
|
|
@ -246,7 +249,8 @@ class Vector {
|
|
|
|
|
CopyCPUDataToCUDA(place);
|
|
|
|
|
SetFlag(kDataInCUDA);
|
|
|
|
|
} else if (!(place == gpu_->place())) {
|
|
|
|
|
PADDLE_THROW("This situation should not happen.");
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
platform::errors::Unavailable("Unexpected data place mismatch."));
|
|
|
|
|
} else {
|
|
|
|
|
// Not Dirty && DataInCUDA && Device is same
|
|
|
|
|
// Do nothing.
|
|
|
|
@ -501,27 +505,29 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T *CUDAData(platform::Place place) const {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Vector::CUDAData() method is not supported in CPU-only version");
|
|
|
|
|
PADDLE_THROW(platform::errors::Unavailable(
|
|
|
|
|
"Vector::CUDAData() method is not supported in CPU-only version."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T *CUDAMutableData(platform::Place place) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
PADDLE_THROW(platform::errors::Unavailable(
|
|
|
|
|
"Vector::CUDAMutableData() method is not supported in CPU-only "
|
|
|
|
|
"version");
|
|
|
|
|
"version."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T *Data(platform::Place place) const {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::is_cpu_place(place),
|
|
|
|
|
"Vector::Data() method is not supported when not in CPUPlace");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_cpu_place(place), true,
|
|
|
|
|
platform::errors::Unavailable(
|
|
|
|
|
"Vector::Data() method is not supported when not in CPUPlace."));
|
|
|
|
|
return this->data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T *MutableData(platform::Place place) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::is_cpu_place(place),
|
|
|
|
|
"Vector::MutableData() method is not supported when not in CPUPlace");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_cpu_place(place), true,
|
|
|
|
|
platform::errors::Unavailable("Vector::MutableData() method is not "
|
|
|
|
|
"supported when not in CPUPlace."));
|
|
|
|
|
return this->data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|