"remove context random seeding "

revert-3824-remove_grad_op_type
dongzhihong 8 years ago
parent 6fc6647c31
commit 70825506d1

@ -21,12 +21,10 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
}
CPUDeviceContext::CPUDeviceContext() {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice());
}
@ -44,7 +42,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
}
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
SetDeviceId(place_.device);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
@ -111,8 +108,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed(
curand_generator_, random_seed_));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
}
return curand_generator_;
}

@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <chrono>
#include <memory>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
@ -40,7 +39,6 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {
public:
typedef std::mt19937 random_generator_type;
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
@ -49,16 +47,7 @@ class CPUDeviceContext : public DeviceContext {
Place GetPlace() const override;
random_generator_type& RandGenerator() {
if (!rand_generator_) {
rand_generator_.reset(new random_generator_type(random_seed_));
}
return *rand_generator_.get();
}
private:
unsigned random_seed_;
std::unique_ptr<random_generator_type> rand_generator_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
@ -97,7 +86,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
private:
unsigned random_seed_;
uint64_t seed_;
// clang-format off
cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;

Loading…
Cancel
Save