You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.9 KiB
61 lines
1.9 KiB
9 years ago
|
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
|
||
|
|
||
|
#include "ThreadLocal.h"
|
||
|
|
||
|
#include "Thread.h"
|
||
|
|
||
|
#include "CommandLineParser.h"
|
||
|
|
||
|
// Command-line switch (default: off). When enabled, thread-local random state
// is seeded from the global default seed alone; when disabled, seeds are
// derived per thread (see ThreadLocalRand::getSeed and
// ThreadLocalRandomEngine::get in this file).
P_DEFINE_bool(thread_local_rand_use_global_seed,
              false,
              "Whether to use global seed in thread local rand.");
|
||
|
|
||
|
namespace paddle {
|
||
|
|
||
|
unsigned int ThreadLocalRand::defaultSeed_ = 1;
|
||
|
ThreadLocal<unsigned int> ThreadLocalRand::seed_;
|
||
|
|
||
|
unsigned int* ThreadLocalRand::getSeed() {
|
||
|
unsigned int* p = seed_.get(false /*createLocal*/);
|
||
|
if (!p) { // init seed
|
||
|
if (FLAGS_thread_local_rand_use_global_seed) {
|
||
|
p = new unsigned int(defaultSeed_);
|
||
|
} else if (getpid() == gettid()) { // main thread
|
||
|
// deterministic, but differs from global srand()
|
||
|
p = new unsigned int(defaultSeed_ - 1);
|
||
|
} else {
|
||
|
p = new unsigned int(defaultSeed_ + gettid());
|
||
|
LOG(INFO) << "thread use undeterministic rand seed:" << *p;
|
||
|
}
|
||
|
seed_.set(p);
|
||
|
}
|
||
|
return p;
|
||
|
}
|
||
|
|
||
|
ThreadLocal<std::default_random_engine> ThreadLocalRandomEngine::engine_;
|
||
|
std::default_random_engine& ThreadLocalRandomEngine::get() {
|
||
|
auto engine = engine_.get(false);
|
||
|
if (!engine) {
|
||
|
engine = new std::default_random_engine;
|
||
|
int defaultSeed = ThreadLocalRand::getDefaultSeed();
|
||
|
engine->seed(FLAGS_thread_local_rand_use_global_seed
|
||
|
? defaultSeed
|
||
|
: defaultSeed + gettid());
|
||
|
engine_.set(engine);
|
||
|
}
|
||
|
return *engine;
|
||
|
}
|
||
|
|
||
|
} // namespace paddle
|