@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/platform/device_context.h"
+#include "paddle/memory/memory.h"
 
 namespace paddle {
 namespace platform {
@@ -36,6 +37,59 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
 
 #ifndef PADDLE_ONLY_CPU
 
+class EigenCudaStreamDevice : public Eigen::StreamInterface {
+ public:
+  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
+    Eigen::initializeDeviceProp();
+  }
+  ~EigenCudaStreamDevice() override {}
+
+  void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) {
+    stream_ = cuda_stream;
+    place_ = place;
+    device_prop_ = &Eigen::m_deviceProperties[place.device];
+  }
+
+  const cudaStream_t& stream() const override { return *stream_; }
+
+  const cudaDeviceProp& deviceProperties() const override {
+    return *device_prop_;
+  }
+
+  void* allocate(size_t num_bytes) const override {
+    return paddle::memory::Alloc(place_, num_bytes);
+  }
+
+  void deallocate(void* buffer) const override {
+    paddle::memory::Free(place_, buffer);
+  }
+
+  void* scratchpad() const override {
+    if (scratch_ == NULL) {
+      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
+    }
+    return scratch_;
+  }
+
+  unsigned int* semaphore() const override {
+    if (semaphore_ == NULL) {
+      char* scratch =
+          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
+      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
+      PADDLE_ENFORCE(
+          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
+    }
+    return semaphore_;
+  }
+
+ private:
+  GPUPlace place_;
+  const cudaStream_t* stream_;         // not owned;
+  const cudaDeviceProp* device_prop_;  // not owned;
+  mutable void* scratch_;
+  mutable unsigned int* semaphore_;
+};
+
 template <>
 Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
   return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
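For readers unfamiliar with Eigen's GPU backend: `Eigen::GpuDevice` consumes any `Eigen::StreamInterface` implementation, which is why the class added above only has to supply a stream, device properties, an allocator, and the scratch/semaphore area. A minimal sketch of that wiring, not part of this patch; the includes and `GPUPlace(0)` are assumptions and error checking is elided:

```cpp
#define EIGEN_USE_GPU
#include <unsupported/Eigen/CXX11/Tensor>
#include "paddle/platform/device_context.h"  // assumed include path

void Sketch() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  EigenCudaStreamDevice eigen_stream;
  eigen_stream.Reinitialize(&stream, GPUPlace(0));  // bind stream + device 0

  // GpuDevice holds only a pointer to the interface; tensor expressions
  // evaluated on it allocate via paddle::memory::Alloc and launch on `stream`.
  Eigen::GpuDevice device(&eigen_stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}
```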
@@ -43,19 +97,9 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
 
 CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
   SetDeviceId(place_.device);
-  // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
-  // here will cause segment fault. We must implement a class derived from
-  // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
-  // later. Please refer to the implementation of class EigenCudaStreamDevice
-  // in TensorFlow.
-  //
-  // We find that CUDA 7 introduces a new option, the per-thread default stream,
-  // that has two effects. Please refer to https://devblogs.nvidia.com/
-  // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
-  //
-  // So, we decide to use default stream and add –default-stream per-thread nvcc
-  // flag. Than, two threads with two CUDADeviceContexts will run parallelly.
-  eigen_stream_.reset(new Eigen::CudaStreamDevice());
+  PADDLE_ENFORCE(cudaStreamCreate(&stream_));
+  eigen_stream_.reset(new EigenCudaStreamDevice());
+  eigen_stream_->Reinitialize(&stream_, place);
+  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
 }
 
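The practical effect of this hunk is that each `CUDADeviceContext` now owns an explicit stream instead of relying on the per-thread default stream. A hedged sketch of the intended concurrency, assuming the declarations in `paddle/platform/device_context.h`; two threads with two contexts on the same device can overlap their GPU work:

```cpp
#include <thread>
#include "paddle/platform/device_context.h"  // assumed include path

// Each context owns a distinct cudaStream_t, so work submitted through the
// two contexts below no longer serializes on a shared default stream.
void Worker() {
  paddle::platform::CUDADeviceContext ctx(paddle::platform::GPUPlace(0));
  // ... launch work through ctx.eigen_device() / ctx.cublas_handle() ...
  ctx.Wait();  // synchronizes only ctx's own stream (see the Wait() hunk below)
}

int main() {
  std::thread t1(Worker);
  std::thread t2(Worker);
  t1.join();
  t2.join();
  return 0;
}
```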
@@ -75,12 +119,13 @@ CUDADeviceContext::~CUDADeviceContext() {
   }
   eigen_stream_.reset();
   eigen_device_.reset();
+  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
 }
 
 Place CUDADeviceContext::GetPlace() const { return place_; }
 
 void CUDADeviceContext::Wait() const {
-  PADDLE_ENFORCE(cudaStreamSynchronize(0));
+  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
 }
 
 Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
@@ -91,6 +136,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() {
   if (!cublas_handle_) {
     SetDeviceId(place_.device);
     PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+    PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
   }
   return cublas_handle_;
 }
@@ -99,10 +145,13 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
   if (!cudnn_handle_) {
     SetDeviceId(place_.device);
     PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
   }
   return cudnn_handle_;
 }
 
+cudaStream_t CUDADeviceContext::stream() { return stream_; }
+
 curandGenerator_t CUDADeviceContext::curand_generator() {
   if (!curand_generator_) {
     SetDeviceId(place_.device);
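With the cuBLAS/cuDNN handles now bound to `stream_` and the new `stream()` accessor exposed, library calls and raw CUDA work sequence correctly on one stream. A hedged caller-side sketch; `cublasSscal` is the standard cuBLAS scaling routine, and `d_x` is assumed to be a device pointer to `n` floats:

```cpp
#include <cublas_v2.h>
#include "paddle/platform/device_context.h"  // assumed include path

void ScaleInPlace(paddle::platform::CUDADeviceContext* ctx, float* d_x, int n) {
  const float alpha = 2.0f;
  // Runs on ctx->stream() because of the cublasSetStream call above, so it
  // orders correctly with any other work enqueued on that stream.
  cublasSscal(ctx->cublas_handle(), n, &alpha, d_x, 1);
  ctx->Wait();  // cudaStreamSynchronize(stream_) under the hood
}
```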
@@ -110,6 +159,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
                                                   CURAND_RNG_PSEUDO_DEFAULT));
     PADDLE_ENFORCE(
         dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
+
+    PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
   }
   return curand_generator_;
 }
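Likewise, the lazily created cuRAND generator is now bound to the context's stream, so generation is asynchronous with respect to the host. A minimal hedged sketch of the resulting call pattern; `curandGenerateUniform` is the standard cuRAND call, and `d_data` is assumed to point at `n` floats of device memory:

```cpp
#include <curand.h>
#include "paddle/platform/device_context.h"  // assumed include path

void FillUniform(paddle::platform::CUDADeviceContext* ctx, float* d_data,
                 size_t n) {
  // First call creates the generator and binds it to ctx's stream.
  curandGenerator_t gen = ctx->curand_generator();
  curandGenerateUniform(gen, d_data, n);  // enqueued on ctx->stream()
  ctx->Wait();                            // block until generation completes
}
```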