support generator seed in related kernals test=develop (#26495)

test_feature_precision_test_c
yaoxuefeng 5 years ago committed by GitHub
parent ae4724cfd1
commit efee426742
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -123,7 +123,7 @@ cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_t
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
if (WITH_GPU)
nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
else()

@ -61,7 +61,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator)
cc_test(communicator_test SRCS communicator_test.cc DEPS communicator)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc

@ -28,6 +28,7 @@
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/selected_rows.h"
@ -96,7 +97,12 @@ class UniformInitializer : public Initializer {
dist_ = std::uniform_real_distribution<float>(min_, max_);
}
float GetValue() override { return dist_(random_engine_); }
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
private:
float min_;
@ -141,7 +147,12 @@ class GaussianInitializer : public Initializer {
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override { return dist_(random_engine_); }
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
private:
float std_;

@ -18,6 +18,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@ -55,6 +56,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
return;
}
bool init_generator_py = framework::Generator::GetInstance()->is_init_py;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
@ -71,7 +74,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
float cur_random =
init_generator_py
? dist(framework::Generator::GetInstance()->GetCPUEngine())
: dist(engine);
if (cur_random < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {

@ -14,6 +14,7 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
@ -31,23 +32,30 @@ class CPUGaussianRandomKernel : public framework::OpKernel<T> {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
const std::string op_type = "gaussian_random";
auto shape = GetShape(context, op_type);
tensor->Resize(shape);
int64_t size = tensor->numel();
T* data = tensor->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
}
};

@ -18,6 +18,7 @@ limitations under the License. */
#include <queue>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace operators {
@ -31,7 +32,12 @@ UniformSampler::UniformSampler(int64_t range, unsigned int seed)
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
int64_t UniformSampler::Sample() const {
return framework::Generator::GetInstance()->is_init_py
? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*dist_)(*random_engine_);
// return (*dist_)(*random_engine_);
}
float UniformSampler::Probability(int64_t value) const { return inv_range_; }
@ -46,8 +52,11 @@ int64_t LogUniformSampler::Sample() const {
// inverse_transform_sampling method
// More details:
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
const int64_t value =
static_cast<int64_t>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
auto cur_random =
framework::Generator::GetInstance()->is_init_py
? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*dist_)(*random_engine_);
const int64_t value = static_cast<int64_t>(exp(cur_random * log_range_)) - 1;
// Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_.
return value % range_;
@ -75,8 +84,14 @@ CustomSampler::CustomSampler(int64_t range, const float *probabilities,
}
int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
auto index =
framework::Generator::GetInstance()->is_init_py
? (*int_dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*int_dist_)(*random_engine_);
auto p =
framework::Generator::GetInstance()->is_init_py
? (*real_dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*real_dist_)(*random_engine_);
if (p > alias_probs_[index]) {
int alias = alias_[index];

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"
@ -28,21 +29,29 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
const std::string op_type = "gaussian_random";
auto shape = GetShape(context, op_type);
tensor->Resize(shape);
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
std::normal_distribution<T> dist(mean, std);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
tensor->set_layout(DataLayout::kMKLDNN);

@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
@ -43,15 +44,25 @@ class CPURandintKernel : public framework::OpKernel<T> {
T* data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_int_distribution<T> dist(ctx.Attr<int>("low"),
ctx.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) data[i] = dist(gen_engine);
} else {
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
}
};

@ -19,6 +19,7 @@ limitations under the License. */
#include <ctime>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
@ -31,11 +32,17 @@ static inline void random_permate(T* data_ptr, int num, unsigned int seed) {
for (int i = 0; i < num; ++i) {
data_ptr[i] = static_cast<T>(i);
}
if (seed == 0) {
seed = std::random_device()();
if (framework::Generator::GetInstance()->is_init_py) {
std::shuffle(data_ptr, data_ptr + num,
framework::Generator::GetInstance()->GetCPUEngine());
} else {
if (seed == 0) {
seed = std::random_device()();
}
std::srand(seed);
std::random_shuffle(data_ptr, data_ptr + num);
}
std::srand(seed);
std::random_shuffle(data_ptr, data_ptr + num);
}
template <typename DeviceContext, typename T>
@ -51,6 +58,7 @@ class RandpermKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(ctx.GetPlace())) {
T* out_data = out_tensor->mutable_data<T>(platform::CPUPlace());
random_permate<T>(out_data, n, seed);
} else {
framework::Tensor tmp_tensor;
tmp_tensor.Resize(framework::make_ddim({n}));

@ -21,6 +21,7 @@
#include <sstream>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@ -61,7 +62,9 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<int64_t> ids(batch_size);
for (int i = 0; i < batch_size; ++i) {
T r = dist(engine);
T r = framework::Generator::GetInstance()->is_init_py
? dist(framework::Generator::GetInstance()->GetCPUEngine())
: dist(engine);
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {

@ -14,6 +14,7 @@ limitations under the License. */
#include <limits>
#include <random>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@ -161,18 +162,27 @@ class CPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(engine));
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(gen_engine));
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(engine));
}
}
}
};

@ -29,7 +29,7 @@ class Generator(object):
seed_in = default_rng_seed_val
if self.device == "CPU":
self.generator = core.Generator()
self.generator.manual_seed(seed_in)
# self.generator.manual_seed(seed_in)
else:
raise ValueError(
"generator class with device %s does not exist, currently only support generator with device 'CPU' "

@ -224,7 +224,8 @@ def _expand_bbox_targets(bbox_targets_input, class_nums, is_cls_agnostic):
class TestGenerateProposalLabelsOp(OpTest):
def set_data(self):
self.use_random = False
#self.use_random = False
self.init_use_random()
self.init_test_cascade()
self.init_test_params()
self.init_test_input()
@ -267,6 +268,9 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_cascade(self, ):
self.is_cascade_rcnn = False
def init_use_random(self):
self.use_random = False
def init_test_params(self):
self.batch_size_per_im = 512
self.fg_fraction = 0.25
@ -329,6 +333,28 @@ class TestCascade(TestGenerateProposalLabelsOp):
self.is_cascade_rcnn = True
class TestUseRandom(TestGenerateProposalLabelsOp):
def init_use_random(self):
self.use_random = True
self.is_cascade_rcnn = False
def test_check_output(self):
self.check_output_customized(self.verify_out)
def verify_out(self, outs):
print("skip")
def init_test_params(self):
self.batch_size_per_im = 512
self.fg_fraction = 0.025
self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0
self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
self.is_cls_agnostic = False
self.class_nums = 2 if self.is_cls_agnostic else 81
class TestClsAgnostic(TestCascade):
def init_test_params(self):
self.batch_size_per_im = 512

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save