From d525abed955b5dd2e6c711205c11ac6a3bcca789 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Mon, 21 Aug 2017 13:43:07 +0800 Subject: [PATCH 1/8] refine random related ops --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/gaussian_random_op.cc | 35 ++---------- paddle/operators/gaussian_random_op.cu | 41 ++------------ paddle/operators/gaussian_random_op.h | 38 +++++++++++++ paddle/operators/math/math_function.cc | 22 ++++++++ paddle/operators/math/math_function.cu | 36 ++++++++++++ paddle/operators/math/math_function.h | 8 +++ paddle/operators/mul_op.cc | 1 - paddle/operators/uniform_random_op.cc | 39 ++----------- paddle/operators/uniform_random_op.cu | 55 +------------------ paddle/operators/uniform_random_op.h | 38 +++++++++++++ paddle/platform/device_context.cc | 36 ++++++------ paddle/platform/device_context.h | 20 ++++--- .../paddle/v2/framework/tests/CMakeLists.txt | 2 +- .../tests/test_gaussian_random_op.py | 7 +-- .../framework/tests/test_uniform_random_op.py | 7 +-- 16 files changed, 192 insertions(+), 197 deletions(-) create mode 100644 paddle/operators/gaussian_random_op.h create mode 100644 paddle/operators/uniform_random_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e4..8f22a5fbc3 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu) + SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f30bbce958..aba8c6e5cd 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -12,36 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include <random> -#include "paddle/framework/op_registry.h" +#include "paddle/operators/gaussian_random_op.h" namespace paddle { namespace operators { -template <typename T> -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr<float>("mean"); - float std = context.op_.GetAttr<float>("std"); - auto* tensor = context.Output<framework::Tensor>(0); - T* data = tensor->mutable_data<T>(context.GetPlace()); - - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr<int>("seed"); - if (seed == 0) { - seed = std::random_device()(); - } - std::mt19937 g(seed); - std::normal_distribution<T> distribution(mean, std); - ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); - } - } -}; - class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -70,10 +45,6 @@ Use to initialize tensor with gaussian random generator. AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<float>("mean", "mean value of random.").SetDefault(.0f); AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f); - AddAttr<int>("seed", - "Random seed of generator." - "0 means use system wide seed") - .SetDefault(0); } }; @@ -83,4 +54,6 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); +REGISTER_OP_CPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel<paddle::platform::CPUPlace, float>); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9..31be16fdc8 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -12,42 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include <memory> -#include <random> -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template <typename T> -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr<float>("mean"); - float std = context.op_.GetAttr<float>("std"); - auto* tensor = context.Output<framework::Tensor>(0); - T* data = tensor->mutable_data<T>(context.GetPlace()); - - int seed = context.op_.GetAttr<int>("seed"); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); - } -}; - -} // namespace operators -} // namespace paddle +#include "paddle/operators/gaussian_random_op.h" namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); +REGISTER_OP_GPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel<paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h new file mode 100644 index 0000000000..041390e954 --- /dev/null +++ b/paddle/operators/gaussian_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template <typename Place, typename T> +class GaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + T mean = static_cast<T>(context.op_.GetAttr<float>("mean")); + T std = static_cast<T>(context.op_.GetAttr<float>("std")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast<platform::DeviceContext*>(context.device_context_); + math::RandGaussian<Place, T>(n, mean, std, data, device_context); + } +}; +} +} diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1e86fc3d16..da59044899 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,28 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } +template <> +void RandUniform<platform::CPUPlace, float>(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); + std::uniform_real_distribution<float> distribution(min, max); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + +template <> +void RandGaussian<platform::CPUPlace, float>(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); + std::normal_distribution<float> distribution(mean, std); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index da40b27c94..5a400d4445 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include <thrust/device_ptr.h> +#include <thrust/iterator/counting_iterator.h> +#include <thrust/random.h> +#include <thrust/transform.h> #include "paddle/operators/math/math_function.h" namespace paddle { @@ -122,6 +126,38 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } +template <> +void RandUniform<platform::GPUPlace, float>(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); + thrust::uniform_real_distribution<float> distribution(min, max); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator<unsigned int> index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr<float>(output), distribution(engine)); +} + +template <> +void RandGaussian<platform::GPUPlace, float>(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); + thrust::normal_distribution<float> distribution(mean, std); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator<unsigned int> index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr<float>(output), distribution(engine)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 155589fadb..ea15e8fd2b 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -77,6 +77,14 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template <typename Place, typename T> +void RandUniform(const int n, const T min, const T max, T* output, + platform::DeviceContext* context); + +template <typename Place, typename T> +void RandGaussian(const int n, const T mean, const T std, T* output, + platform::DeviceContext* context); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 460e458ca4..173cc3850c 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index a0a0d4d914..81487a6bd8 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -12,39 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include <random> -#include <type_traits> -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template <typename T> -class CPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output<framework::Tensor>("Out"); - T* data = tensor->mutable_data<T>(context.GetPlace()); - unsigned int seed = - static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::uniform_real_distribution<T> dist( - static_cast<T>(context.op_.GetAttr<float>("min")), - static_cast<T>(context.op_.GetAttr<float>("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { - data[i] = dist(engine); - } - } -}; - class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -72,10 +44,6 @@ Used to initialize tensor with uniform random generator. AddAttr<std::vector<int>>("dims", "the dimension of random tensor"); AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f); - AddAttr<int>("seed", - "Random seed of uniform random. " - "0 means generate a seed by system") - .SetDefault(0); } }; } // namespace operators @@ -83,5 +51,6 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL(uniform_random, - paddle::operators::CPUUniformRandomKernel<float>); +REGISTER_OP_CPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel<paddle::platform::CPUPlace, float>); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 7a243555b6..91368fa73e 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -12,60 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include <thrust/device_ptr.h> -#include <thrust/iterator/counting_iterator.h> -#include <thrust/random.h> -#include <thrust/transform.h> -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -template <typename T> -struct UniformGenerator { - T min_, max_; - unsigned int seed_; - - __host__ __device__ UniformGenerator(T min, T max, int seed) - : min_(min), max_(max), seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution<T> dist(min_, max_); - rng.discard(n); - return dist(rng); - } -}; - -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template <typename T> -class GPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output<framework::Tensor>("Out"); - T* data = tensor->mutable_data<T>(context.GetPlace()); - unsigned int seed = - static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - T min = static_cast<T>(context.op_.GetAttr<float>("min")); - T max = static_cast<T>(context.op_.GetAttr<float>("max")); - thrust::counting_iterator<unsigned int> index_sequence_begin(0); - ssize_t N = framework::product(tensor->dims()); - thrust::transform(index_sequence_begin, index_sequence_begin + N, - thrust::device_ptr<T>(data), - UniformGenerator<T>(min, max, seed)); - } -}; - -} // namespace operators -} // namespace paddle - REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel<float>); + paddle::operators::GPUUniformRandomKernel< + paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h new file mode 100644 index 0000000000..ec009b025e --- /dev/null +++ b/paddle/operators/uniform_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template <typename Place, typename T> +class UniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + T min = static_cast<T>(context.op_.GetAttr<float>("min")); + T max = static_cast<T>(context.op_.GetAttr<float>("max")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast<platform::DeviceContext*>(context.device_context_); + math::RandUniform<Place, T>(n, min, max, data, device_context); + } +}; +} +} diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index f92c15ae45..fabbb55443 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,8 +25,17 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { eigen_device_.reset(new Eigen::DefaultDevice()); + rand_seed_ = rand_seed; +} + +std::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new std::minstd_rand()); + rand_engine_->seed(rand_seed_); + } + return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -95,7 +104,8 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) + : place_(place), seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -114,9 +124,6 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -150,21 +157,16 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -cudaStream_t CUDADeviceContext::stream() { return stream_; } - -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); +thrust::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new thrust::minstd_rand()); + rand_engine_->seed(rand_seed_); } - return curand_generator_; + return *(rand_engine_.get()); } +cudaStream_t CUDADeviceContext::stream() { return stream_; } + #endif // PADDLE_ONLY_CPU } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index c5042ae33e..e4de3807cd 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,9 +15,10 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU +#include <thrust/device_ptr.h> +#include <thrust/random.h> #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,14 +41,18 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; + std::minstd_rand& rand_engine(); + Place GetPlace() const override; private: + int rand_seed_; + std::unique_ptr<std::minstd_rand> rand_engine_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_; }; @@ -56,7 +61,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -75,8 +80,7 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); + thrust::minstd_rand& CPUDeviceContext::rand_engine(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); @@ -85,18 +89,16 @@ class CUDADeviceContext : public DeviceContext { private: GPUPlace place_; - private: std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; - private: - uint64_t seed_; + uint64_t rand_seed_; + std::unique_ptr<thrust::minstd_rand> rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a07130..b07a65f4d1 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index f95ed70b58..367d21b301 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,12 +17,7 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", - Out="Out", - dims=[1000, 784], - mean=.0, - std=1., - seed=10) + "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index c3d2bb44da..95c36a27cf 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,12 +17,7 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", - Out="X", - dims=[1000, 784], - min=-5.0, - max=10.0, - seed=10) + "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) op.infer_shape(scope) ctx = core.DeviceContext.create(place) From 7c274dc0a16b77fae0faf527ef02a1f72abad593 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Mon, 21 Aug 2017 16:41:22 +0800 Subject: [PATCH 2/8] use curand --- paddle/operators/math/math_function.cc | 9 +++++ paddle/operators/math/math_function.cu | 56 ++++++++++++++++++-------- paddle/operators/math/math_function.h | 8 ++++ paddle/platform/device_context.cc | 15 ++++--- paddle/platform/device_context.h | 6 +-- 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index da59044899..d0b1f8ee48 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,15 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } +template <> +void Set<typename CPUPlace, typename float>(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); + framework::EigenVector::Type<T> out(output, n); + out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); +} + template <> void RandUniform<platform::CPUPlace, float>(const int n, const float min, const float max, float* output, diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 5a400d4445..76bbf790db 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -126,20 +126,48 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } +template <> +void Set<typename GPUPlace, typename float>(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); + framework::EigenVector::Type<T> out(output, n); + out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); +} + +template <typename T> +__global__ void UniformShift(const int n, const T min, const T max, T* x) { + float scale = max - min; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + x[i] = x[i] * scale + min; + } +} + template <> void RandUniform<platform::GPUPlace, float>(const int n, const float min, const float max, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - thrust::uniform_real_distribution<float> distribution(min, max); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator<unsigned int> index_sequence_begin(0); + PADDLE_ENFORCE( + curandGenerateUniform(cuda_context->curand_generator(), output, n)); + int block = 512; + int grid = (n + block - 1) / block; + UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max, + output); +} - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr<float>(output), distribution(engine)); +template <typename T> +int HandleOddLengthRandGaussian(const int n, const T mean, const T std, + T* output, CUDADeviceContext* context) { + if (n % 2 == 1) { + std::default_random_engine generator; + std::normal_distribution<T> distribution(mean, std); + const T random_value = distribution(generator); + Set<T, platform::GPUPlace>(1, random_value, output + (n - 1), context); + return n - 1; + } + return n; } template <> @@ -147,15 +175,11 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean, const float std, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - thrust::normal_distribution<float> distribution(mean, std); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator<unsigned int> index_sequence_begin(0); - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr<float>(output), distribution(engine)); + const int even_n = + HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context); + PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, + even_n, mean, std)); } } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index ea15e8fd2b..afe6de7483 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -54,6 +54,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" +#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -77,6 +78,13 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template <typename Place, typename T> +void Set(const int n, const T alpha, T* output, + platform::DeviceContext* context) { + framework::EigenVector::Type<T> out(output, n); + out.device(*(context->eigen_device())) = t.constant(T(alpha)); +} + template <typename Place, typename T> void RandUniform(const int n, const T min, const T max, T* output, platform::DeviceContext* context); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index fabbb55443..5fd93555a5 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -157,12 +157,17 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -thrust::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new thrust::minstd_rand()); - rand_engine_->seed(rand_seed_); +curandGenerator_t CUDADeviceContext::curand_generator() { + if (!curand_generator_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, + CURAND_RNG_PSEUDO_DEFAULT)); + PADDLE_ENFORCE( + dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + + PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } - return *(rand_engine_.get()); + return curand_generator_; } cudaStream_t CUDADeviceContext::stream() { return stream_; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e4de3807cd..7013343a8d 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,10 +15,9 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU -#include <thrust/device_ptr.h> -#include <thrust/random.h> #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -80,7 +79,8 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - thrust::minstd_rand& CPUDeviceContext::rand_engine(); + /*! \brief Return curand handle in the device context. */ + curandGenerator_t curand_generator(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); From 2f47f35b3efec36189a4c6757490b897130d3028 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Mon, 21 Aug 2017 09:12:25 +0000 Subject: [PATCH 3/8] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cc | 10 +++++----- paddle/operators/math/math_function.cu | 15 ++++++++------- paddle/operators/math/math_function.h | 7 ++----- paddle/operators/uniform_random_op.cu | 9 +++------ paddle/platform/device_context.cc | 10 +++++----- paddle/platform/device_context.h | 6 +++--- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ed51d416ed..228f463f2b 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d0b1f8ee48..a098e02f95 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -110,12 +110,12 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, } template <> -void Set<typename CPUPlace, typename float>(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set<platform::CPUPlace, float>(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); - framework::EigenVector::Type<T> out(output, n); - out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector<float>::Type out(output, n); + out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } template <> diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 76bbf790db..3ff622f308 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -127,12 +127,12 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, } template <> -void Set<typename GPUPlace, typename float>(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set<platform::GPUPlace, float>(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - framework::EigenVector::Type<T> out(output, n); - out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector<float>::Type out(output, n); + out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } template <typename T> @@ -159,12 +159,13 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min, template <typename T> int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, CUDADeviceContext* context) { + T* output, + platform::CUDADeviceContext* context) { if (n % 2 == 1) { std::default_random_engine generator; std::normal_distribution<T> distribution(mean, std); const T random_value = distribution(generator); - Set<T, platform::GPUPlace>(1, random_value, output + (n - 1), context); + Set<platform::GPUPlace, T>(1, random_value, output + (n - 1), context); return n - 1; } return n; diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index afe6de7483..6543a1b515 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,9 +52,9 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include <cmath> +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" -#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -80,10 +80,7 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, template <typename Place, typename T> void Set(const int n, const T alpha, T* output, - platform::DeviceContext* context) { - framework::EigenVector::Type<T> out(output, n); - out.device(*(context->eigen_device())) = t.constant(T(alpha)); -} + platform::DeviceContext* context); template <typename Place, typename T> void RandUniform(const int n, const T min, const T max, T* output, diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 91368fa73e..1bfffc4778 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -14,9 +14,6 @@ #include "paddle/operators/uniform_random_op.h" -namespace paddle { -namespace operators { - -REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel< - paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel<paddle::platform::GPUPlace, float>); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 5fd93555a5..ad9b4e42f3 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,9 +25,9 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = rand_seed; + rand_seed_ = seed; } std::minstd_rand& CPUDeviceContext::rand_engine() { @@ -105,7 +105,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { } CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), seed_(seed) { + : place_(place), rand_seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -162,8 +162,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( + curand_generator_, rand_seed_)); PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7013343a8d..e18f48fef5 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -40,7 +40,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); + explicit CPUDeviceContext(CPUPlace place, int seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -60,7 +60,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); + explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -93,12 +93,12 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; uint64_t rand_seed_; - std::unique_ptr<thrust::minstd_rand> rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; + curandGenerator_t curand_generator_{nullptr}; // clang-format on }; From 08c987d7c086e4176a27f2685712bbb9226e635e Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Mon, 21 Aug 2017 17:23:15 +0800 Subject: [PATCH 4/8] use dynload curand --- paddle/operators/gaussian_random_op.h | 4 ++-- paddle/operators/math/math_function.cu | 8 ++++---- paddle/operators/uniform_random_op.h | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h index 041390e954..c90b665fe0 100644 --- a/paddle/operators/gaussian_random_op.h +++ b/paddle/operators/gaussian_random_op.h @@ -34,5 +34,5 @@ class GaussianRandomKernel : public framework::OpKernel { math::RandGaussian<Place, T>(n, mean, std, data, device_context); } }; -} -} +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 3ff622f308..908efe9e0f 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -149,8 +149,8 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min, const float max, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - PADDLE_ENFORCE( - curandGenerateUniform(cuda_context->curand_generator(), output, n)); + PADDLE_ENFORCE(platform::dynload::curandGenerateUniform( + cuda_context->curand_generator(), output, n)); int block = 512; int grid = (n + block - 1) / block; UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max, @@ -179,8 +179,8 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean, const int even_n = HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context); - PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, - even_n, mean, std)); + PADDLE_ENFORCE(platform::dynload::curandGenerateNormal( + cuda_context->curand_generator(), output, even_n, mean, std)); } } // namespace math diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h index ec009b025e..dffa640f84 100644 --- a/paddle/operators/uniform_random_op.h +++ b/paddle/operators/uniform_random_op.h @@ -34,5 +34,5 @@ class UniformRandomKernel : public framework::OpKernel { math::RandUniform<Place, T>(n, min, max, data, device_context); } }; -} -} +} // namespace operators +} // namespace paddle From b054392e2abebb2a55dabeeb2f12e414bbc2c5af Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Mon, 21 Aug 2017 17:46:46 +0800 Subject: [PATCH 5/8] fix gaussion op bug --- paddle/operators/gaussian_random_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index aba8c6e5cd..899f05fa47 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -23,7 +23,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output<framework::Tensor>(0); + auto* tensor = context.Output<framework::Tensor>("Out"); auto dims = GetAttr<std::vector<int>>("dims"); PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. dims must be set."); From 36e8e725669a20b272f9ace1cf7c9df646c840a3 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Tue, 22 Aug 2017 11:40:57 +0800 Subject: [PATCH 6/8] expose random seed to users --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/gaussian_random_op.cc | 42 ++++++++++--- paddle/operators/gaussian_random_op.cu | 61 +++++++++++++++--- paddle/operators/gaussian_random_op.h | 38 ----------- paddle/operators/math/math_function.cc | 22 ------- paddle/operators/math/math_function.cu | 48 -------------- paddle/operators/math/math_function.h | 8 --- paddle/operators/uniform_random_op.cc | 44 ++++++++++--- paddle/operators/uniform_random_op.cu | 63 ++++++++++++++++--- paddle/operators/uniform_random_op.h | 38 ----------- paddle/platform/device_context.cc | 27 +------- paddle/platform/device_context.h | 15 +---- .../tests/test_gaussian_random_op.py | 7 ++- .../framework/tests/test_uniform_random_op.py | 7 ++- 14 files changed, 196 insertions(+), 228 deletions(-) delete mode 100644 paddle/operators/gaussian_random_op.h delete mode 100644 paddle/operators/uniform_random_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8f22a5fbc3..a7c89787e4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) + SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 899f05fa47..dcd2237459 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,22 +1,44 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include <random> +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +template <typename T> +class CPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.op_.GetAttr<float>("mean"); + float std = context.op_.GetAttr<float>("std"); + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + + unsigned int seed = + static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::normal_distribution<T> dist(mean, std); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -43,8 +65,12 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); - AddAttr<float>("mean", "mean value of random.").SetDefault(.0f); - AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f); + AddAttr<int>("seed", + "Random seed of generator." + "0 means use system wide seed") + .SetDefault(0); } }; @@ -54,6 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel<paddle::platform::CPUPlace, float>); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 31be16fdc8..1d312e7b5d 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,20 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include <thrust/device_ptr.h> +#include <thrust/iterator/counting_iterator.h> +#include <thrust/random.h> +#include <thrust/transform.h> +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template <typename T> +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution<T> dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +template <typename T> +class GPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + unsigned int seed = + static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T mean = static_cast<T>(context.op_.GetAttr<float>("mean")); + T std = static_cast<T>(context.op_.GetAttr<float>("std")); + thrust::counting_iterator<unsigned int> index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr<T>(data), + GaussianGenerator<T>(mean, std, seed)); + } +}; + +} // namespace operators +} // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel<paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel<float>); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h deleted file mode 100644 index c90b665fe0..0000000000 --- a/paddle/operators/gaussian_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template <typename Place, typename T> -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output<framework::Tensor>("Out"); - T* data = tensor->mutable_data<T>(context.GetPlace()); - T mean = static_cast<T>(context.op_.GetAttr<float>("mean")); - T std = static_cast<T>(context.op_.GetAttr<float>("std")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast<platform::DeviceContext*>(context.device_context_); - math::RandGaussian<Place, T>(n, mean, std, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index a098e02f95..d9824e5f96 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -118,28 +118,6 @@ void Set<platform::CPUPlace, float>(const int n, const float alpha, out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } -template <> -void RandUniform<platform::CPUPlace, float>(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); - std::uniform_real_distribution<float> distribution(min, max); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - -template <> -void RandGaussian<platform::CPUPlace, float>(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); - std::normal_distribution<float> distribution(mean, std); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 908efe9e0f..9dff6f05fb 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -135,54 +135,6 @@ void Set<platform::GPUPlace, float>(const int n, const float alpha, out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } -template <typename T> -__global__ void UniformShift(const int n, const T min, const T max, T* x) { - float scale = max - min; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; - i += blockDim.x * gridDim.x) { - x[i] = x[i] * scale + min; - } -} - -template <> -void RandUniform<platform::GPUPlace, float>(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - PADDLE_ENFORCE(platform::dynload::curandGenerateUniform( - cuda_context->curand_generator(), output, n)); - int block = 512; - int grid = (n + block - 1) / block; - UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max, - output); -} - -template <typename T> -int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, - platform::CUDADeviceContext* context) { - if (n % 2 == 1) { - std::default_random_engine generator; - std::normal_distribution<T> distribution(mean, std); - const T random_value = distribution(generator); - Set<platform::GPUPlace, T>(1, random_value, output + (n - 1), context); - return n - 1; - } - return n; -} - -template <> -void RandGaussian<platform::GPUPlace, float>(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - - const int even_n = - HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context); - PADDLE_ENFORCE(platform::dynload::curandGenerateNormal( - cuda_context->curand_generator(), output, even_n, mean, std)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 6543a1b515..a0e9660564 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -82,14 +82,6 @@ template <typename Place, typename T> void Set(const int n, const T alpha, T* output, platform::DeviceContext* context); -template <typename Place, typename T> -void RandUniform(const int n, const T min, const T max, T* output, - platform::DeviceContext* context); - -template <typename Place, typename T> -void RandGaussian(const int n, const T mean, const T std, T* output, - platform::DeviceContext* context); - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 81487a6bd8..876b3ef557 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,22 +1,48 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include <random> +#include <type_traits> +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template <typename T> +class CPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + unsigned int seed = + static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution<T> dist( + static_cast<T>(context.op_.GetAttr<float>("min")), + static_cast<T>(context.op_.GetAttr<float>("max"))); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -38,12 +64,15 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr<std::vector<int>>("dims", "the dimension of random tensor"); AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f); + AddAttr<int>("seed", + "Random seed of uniform random. " + "0 means generate a seed by system") + .SetDefault(0); } }; } // namespace operators @@ -51,6 +80,5 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel<paddle::platform::CPUPlace, float>); +REGISTER_OP_CPU_KERNEL(uniform_random, + paddle::operators::CPUUniformRandomKernel<float>); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 1bfffc4778..6716b7c7f2 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,19 +1,68 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include <thrust/device_ptr.h> +#include <thrust/iterator/counting_iterator.h> +#include <thrust/random.h> +#include <thrust/transform.h> +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template <typename T> +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution<T> dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template <typename T> +class GPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output<framework::Tensor>("Out"); + T* data = tensor->mutable_data<T>(context.GetPlace()); + unsigned int seed = + static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T min = static_cast<T>(context.op_.GetAttr<float>("min")); + T max = static_cast<T>(context.op_.GetAttr<float>("max")); + thrust::counting_iterator<unsigned int> index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr<T>(data), + UniformGenerator<T>(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle -REGISTER_OP_GPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel<paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel<float>); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h deleted file mode 100644 index dffa640f84..0000000000 --- a/paddle/operators/uniform_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template <typename Place, typename T> -class UniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output<framework::Tensor>("Out"); - T* data = tensor->mutable_data<T>(context.GetPlace()); - T min = static_cast<T>(context.op_.GetAttr<float>("min")); - T max = static_cast<T>(context.op_.GetAttr<float>("max")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast<platform::DeviceContext*>(context.device_context_); - math::RandUniform<Place, T>(n, min, max, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index ad9b4e42f3..ad212c5b2c 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,17 +25,8 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = seed; -} - -std::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new std::minstd_rand()); - rand_engine_->seed(rand_seed_); - } - return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -104,8 +95,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), rand_seed_(seed) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -157,19 +147,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( - curand_generator_, rand_seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); - } - return curand_generator_; -} - cudaStream_t CUDADeviceContext::stream() { return stream_; } #endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e18f48fef5..11528e1194 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,18 +39,14 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int seed = 0); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; - std::minstd_rand& rand_engine(); - Place GetPlace() const override; private: - int rand_seed_; - std::unique_ptr<std::minstd_rand> rand_engine_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_; }; @@ -60,7 +55,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -79,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); // clang-format on @@ -92,13 +84,10 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; - uint64_t rand_seed_; - // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 367d21b301..f95ed70b58 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,7 +17,12 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) + "gaussian_random", + Out="Out", + dims=[1000, 784], + mean=.0, + std=1., + seed=10) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index 95c36a27cf..c3d2bb44da 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,7 +17,12 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) + "uniform_random", + Out="X", + dims=[1000, 784], + min=-5.0, + max=10.0, + seed=10) op.infer_shape(scope) ctx = core.DeviceContext.create(place) From 1918ad875980a7b5fb54c207e56d86b4376e2505 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Tue, 22 Aug 2017 12:16:30 +0800 Subject: [PATCH 7/8] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cc | 9 --------- paddle/operators/math/math_function.cu | 13 ------------- paddle/operators/math/math_function.h | 5 ----- paddle/platform/device_context_test.cc | 2 -- 5 files changed, 2 insertions(+), 31 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 228f463f2b..ed51d416ed 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d9824e5f96..1e86fc3d16 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,15 +109,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } -template <> -void Set<platform::CPUPlace, float>(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context); - framework::EigenVector<float>::Type out(output, n); - out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 9dff6f05fb..da40b27c94 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include <thrust/device_ptr.h> -#include <thrust/iterator/counting_iterator.h> -#include <thrust/random.h> -#include <thrust/transform.h> #include "paddle/operators/math/math_function.h" namespace paddle { @@ -126,15 +122,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, matrix_b.data<double>(), beta, matrix_out->data<double>(), context); } -template <> -void Set<platform::GPUPlace, float>(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); - framework::EigenVector<float>::Type out(output, n); - out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index a0e9660564..155589fadb 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,7 +52,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include <cmath> -#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -78,10 +77,6 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); -template <typename Place, typename T> -void Set(const int n, const T alpha, T* output, - platform::DeviceContext* context); - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8b764bdcd9..5883a55272 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); ASSERT_NE(nullptr, device_context->stream()); delete device_context; } From aff90d8ee78be398b2984d63f2eb985f15f430d1 Mon Sep 17 00:00:00 2001 From: qijun <qijun1994@hotmail.com> Date: Tue, 22 Aug 2017 04:34:35 +0000 Subject: [PATCH 8/8] fix gpu build error --- paddle/operators/gaussian_random_op.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1d312e7b5d..018a4bfcb2 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -30,7 +30,7 @@ struct GaussianGenerator { __host__ __device__ T operator()(const unsigned int n) const { thrust::minstd_rand rng; rng.seed(seed_); - thrust::normal_distribution<T> dist(min_, max_); + thrust::normal_distribution<T> dist(mean_, std_); rng.discard(n); return dist(rng); } @@ -62,4 +62,4 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_GPU_KERNEL(gaussian_random, - paddle::operators::GPUGaussianRandomKernel<float>); \ No newline at end of file + paddle::operators::GPUGaussianRandomKernel<float>);