Refine code

1. Add copyright info
2. Overload structure for customized random seed
add_depthwiseConv_op_gpu
wanghaoshuang 8 years ago
parent 16ed4a92a5
commit 62efc896e1

@ -1,3 +1,17 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "sampler.h" #include "sampler.h"
namespace paddle { namespace paddle {
@ -7,8 +21,13 @@ Sampler::~Sampler() {}
UniformSampler::UniformSampler(int64 range) UniformSampler::UniformSampler(int64 range)
: Sampler(range), inv_range_(1.0 / range) { : Sampler(range), inv_range_(1.0 / range) {
std::random_device r; random_engine_ = std::make_shared<std::mt19937>(seed_);
random_engine_ = std::make_shared<std::mt19937>(r()); dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
UniformSampler::UniformSampler(int64 range, unsigned int seed)
: Sampler(range, seed), inv_range_(1.0 / range) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range); dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
} }
@ -18,11 +37,15 @@ float UniformSampler::Probability(int64 value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64 range) LogUniformSampler::LogUniformSampler(int64 range)
: Sampler(range), log_range_(log(range + 1)) { : Sampler(range), log_range_(log(range + 1)) {
std::random_device r; random_engine_ = std::make_shared<std::mt19937>(seed_);
random_engine_ = std::make_shared<std::mt19937>(r());
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1); dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
} }
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
: Sampler(range, seed), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
}
int64 LogUniformSampler::Sample() const { int64 LogUniformSampler::Sample() const {
// Got Log Uniform distribution from uniform distribution by // Got Log Uniform distribution from uniform distribution by
// inverse_transform_sampling method // inverse_transform_sampling method

@ -20,14 +20,21 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// TODO: Support for GPU // TODO(wanghaoshuang): Support for GPU
/** /**
* Sample integers from [0, range). * Sample integers from [0, range).
*/ */
class Sampler { class Sampler {
public: public:
explicit Sampler(int64 range) : range_(range) { /* check range > 0*/ explicit Sampler(int64 range) : range_(range) {
PADDLE_ENFORCE_GT(range, 0);
std::random_device r;
seed_ = r();
}
explicit Sampler(int64 range, unsigned int seed)
: range_(range), seed_(seed) {
PADDLE_ENFORCE_GT(range, 0);
} }
virtual ~Sampler(); virtual ~Sampler();
// Sample a single value // Sample a single value
@ -39,6 +46,7 @@ class Sampler {
protected: protected:
const int64 range_; const int64 range_;
unsigned int seed_;
}; };
/** /**
@ -50,6 +58,8 @@ class UniformSampler : public Sampler {
public: public:
explicit UniformSampler(int64 range); explicit UniformSampler(int64 range);
explicit UniformSampler(int64 range, unsigned int seed);
~UniformSampler() override {} ~UniformSampler() override {}
int64 Sample() const override; int64 Sample() const override;
@ -71,6 +81,8 @@ class LogUniformSampler : public Sampler {
public: public:
explicit LogUniformSampler(int64 range); explicit LogUniformSampler(int64 range);
explicit LogUniformSampler(int64 range, unsigned int seed);
~LogUniformSampler() override {} ~LogUniformSampler() override {}
int64 Sample() const override; int64 Sample() const override;

Loading…
Cancel
Save