Merge pull request #5945 from wanghaoshuang/sampler
	
		
	
				
					
				
			Add math function for sampling integersadd_depthwiseConv_op_gpu
						commit
						32cc11e358
					
				@ -0,0 +1,70 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "sampler.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace random {
 | 
				
			||||
 | 
				
			||||
Sampler::~Sampler() {}
 | 
				
			||||
 | 
				
			||||
UniformSampler::UniformSampler(int64 range)
 | 
				
			||||
    : Sampler(range), inv_range_(1.0 / range) {
 | 
				
			||||
  random_engine_ = std::make_shared<std::mt19937>(seed_);
 | 
				
			||||
  dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
UniformSampler::UniformSampler(int64 range, unsigned int seed)
 | 
				
			||||
    : Sampler(range, seed), inv_range_(1.0 / range) {
 | 
				
			||||
  random_engine_ = std::make_shared<std::mt19937>(seed_);
 | 
				
			||||
  dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
 | 
				
			||||
 | 
				
			||||
float UniformSampler::Probability(int64 value) const { return inv_range_; }
 | 
				
			||||
 | 
				
			||||
LogUniformSampler::LogUniformSampler(int64 range)
 | 
				
			||||
    : Sampler(range), log_range_(log(range + 1)) {
 | 
				
			||||
  random_engine_ = std::make_shared<std::mt19937>(seed_);
 | 
				
			||||
  dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
 | 
				
			||||
    : Sampler(range, seed), log_range_(log(range + 1)) {
 | 
				
			||||
  random_engine_ = std::make_shared<std::mt19937>(seed_);
 | 
				
			||||
  dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
 | 
				
			||||
}
 | 
				
			||||
int64 LogUniformSampler::Sample() const {
 | 
				
			||||
  // Got Log Uniform distribution from uniform distribution by
 | 
				
			||||
  // inverse_transform_sampling method
 | 
				
			||||
  // More details:
 | 
				
			||||
  // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
 | 
				
			||||
  const int64 value =
 | 
				
			||||
      static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
 | 
				
			||||
  // Mathematically, value should be <= range_, but might not be due to some
 | 
				
			||||
  // floating point roundoff, so we mod by range_.
 | 
				
			||||
  return value % range_;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
float LogUniformSampler::Probability(int64 value) const {
 | 
				
			||||
  // Given f(x) = 1/[(x+1) * log_range_]
 | 
				
			||||
  // The value's  probability  is integral of f(x) from value to (value + 1)
 | 
				
			||||
  // More details:
 | 
				
			||||
  // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler
 | 
				
			||||
  return (log((value + 2.0) / (value + 1.0))) / log_range_;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
}  // namespace random
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,100 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <random>
 | 
				
			||||
typedef long int64;
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
namespace math {
 | 
				
			||||
 | 
				
			||||
// TODO(wanghaoshuang): Support for GPU
 | 
				
			||||
 | 
				
			||||
/**
 | 
				
			||||
* Sample integers from [0, range).
 | 
				
			||||
*/
 | 
				
			||||
class Sampler {
 | 
				
			||||
 public:
 | 
				
			||||
  explicit Sampler(int64 range) : range_(range) {
 | 
				
			||||
    PADDLE_ENFORCE_GT(range, 0);
 | 
				
			||||
    std::random_device r;
 | 
				
			||||
    seed_ = r();
 | 
				
			||||
  }
 | 
				
			||||
  explicit Sampler(int64 range, unsigned int seed)
 | 
				
			||||
      : range_(range), seed_(seed) {
 | 
				
			||||
    PADDLE_ENFORCE_GT(range, 0);
 | 
				
			||||
  }
 | 
				
			||||
  virtual ~Sampler();
 | 
				
			||||
  // Sample a single value
 | 
				
			||||
  virtual int64 Sample() const = 0;
 | 
				
			||||
  // The probability that a single call to Sample() returns the given value.
 | 
				
			||||
  virtual float Probability(int64 value) const = 0;
 | 
				
			||||
 | 
				
			||||
  int64 range() { return range_; };
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  const int64 range_;
 | 
				
			||||
  unsigned int seed_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
/**
 | 
				
			||||
 * Sample integers from [0, range).
 | 
				
			||||
 * And the distribution function is:
 | 
				
			||||
 * P(x) = 1 / range
 | 
				
			||||
 */
 | 
				
			||||
class UniformSampler : public Sampler {
 | 
				
			||||
 public:
 | 
				
			||||
  explicit UniformSampler(int64 range);
 | 
				
			||||
 | 
				
			||||
  explicit UniformSampler(int64 range, unsigned int seed);
 | 
				
			||||
 | 
				
			||||
  ~UniformSampler() override {}
 | 
				
			||||
 | 
				
			||||
  int64 Sample() const override;
 | 
				
			||||
 | 
				
			||||
  float Probability(int64 value) const override;
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  const float inv_range_;
 | 
				
			||||
  std::shared_ptr<std::mt19937_64> random_engine_;
 | 
				
			||||
  std::shared_ptr<std::uniform_int_distribution<>> dist_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
/**
 | 
				
			||||
 * Sample integers from [0, range).
 | 
				
			||||
 * And the distribution function is:
 | 
				
			||||
 * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1))
 | 
				
			||||
 */
 | 
				
			||||
class LogUniformSampler : public Sampler {
 | 
				
			||||
 public:
 | 
				
			||||
  explicit LogUniformSampler(int64 range);
 | 
				
			||||
 | 
				
			||||
  explicit LogUniformSampler(int64 range, unsigned int seed);
 | 
				
			||||
 | 
				
			||||
  ~LogUniformSampler() override {}
 | 
				
			||||
 | 
				
			||||
  int64 Sample() const override;
 | 
				
			||||
 | 
				
			||||
  float Probability(int64 value) const override;
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  const float log_range_;
 | 
				
			||||
  std::shared_ptr<std::mt19937_64> random_engine_;
 | 
				
			||||
  std::shared_ptr<std::uniform_real_distribution<>> dist_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // math
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue