From f19f394bb8ca47797ca6769141ddad9058a81af2 Mon Sep 17 00:00:00 2001 From: Nat Sutyanyong Date: Sat, 2 May 2020 21:27:59 -0400 Subject: [PATCH] Refactor duplicate code on random device and seed --- .../dataset/engine/datasetops/shuffle_op.cc | 16 ++---- mindspore/ccsrc/dataset/util/CMakeLists.txt | 3 +- mindspore/ccsrc/dataset/util/random.cc | 54 ------------------- mindspore/ccsrc/dataset/util/random.h | 38 ++++++++++++- mindspore/ccsrc/dataset/util/services.cc | 12 ++--- 5 files changed, 44 insertions(+), 79 deletions(-) delete mode 100644 mindspore/ccsrc/dataset/util/random.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index 7b09bcef4d..9867945e36 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -86,20 +86,10 @@ Status ShuffleOp::SelfReset() { // epoch. // If ReshuffleEachEpoch is true, then the first epoch uses the given seed, // and all subsequent epochs will then reset the seed based on random device. - if (!reshuffle_each_epoch_) { - rng_ = std::mt19937_64(shuffle_seed_); - } else { -#if defined(_WIN32) || defined(_WIN64) - unsigned int number; - rand_s(&number); - std::mt19937 random_device{static_cast(number)}; -#else - std::random_device random_device("/dev/urandom"); -#endif - std::uniform_int_distribution distribution(0, std::numeric_limits::max()); - shuffle_seed_ = distribution(random_device); - rng_ = std::mt19937_64(shuffle_seed_); + if (reshuffle_each_epoch_) { + shuffle_seed_ = GetNewSeed(); } + rng_ = std::mt19937_64(shuffle_seed_); shuffle_buffer_ = std::make_unique(); buffer_counter_ = 0; shuffle_last_row_idx_ = 0; diff --git a/mindspore/ccsrc/dataset/util/CMakeLists.txt b/mindspore/ccsrc/dataset/util/CMakeLists.txt index 9ae93618ab..b0630f4005 100644 --- a/mindspore/ccsrc/dataset/util/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/util/CMakeLists.txt @@ -12,5 +12,4 @@ add_library(utils OBJECT status.cc path.cc wait_post.cc - sig_handler.cc - random.cc) + sig_handler.cc) diff --git a/mindspore/ccsrc/dataset/util/random.cc b/mindspore/ccsrc/dataset/util/random.cc deleted file mode 100644 index 43b3ee4afd..0000000000 --- a/mindspore/ccsrc/dataset/util/random.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_RANDOM_H_ -#define DATASET_UTIL_RANDOM_H_ - -#include "dataset/util/random.h" - -#if defined(_WIN32) || defined(_WIn64) -#include -#endif -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -uint32_t GetSeed() { - uint32_t seed = GlobalContext::config_manager()->seed(); - if (seed == std::mt19937::default_seed) { -#if defined(_WIN32) || defined(_WIN64) - unsigned int number; - rand_s(&number); - std::mt19937 random_device{static_cast(number)}; -#else - std::random_device random_device("/dev/urandom"); -#endif - std::uniform_int_distribution distribution(0, std::numeric_limits::max()); - seed = distribution(random_device); - } - - return seed; -} -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/dataset/util/random.h b/mindspore/ccsrc/dataset/util/random.h index fa9e18f707..6c70d6c7ef 100644 --- a/mindspore/ccsrc/dataset/util/random.h +++ b/mindspore/ccsrc/dataset/util/random.h @@ -15,9 +15,45 @@ */ #ifndef DATASET_UTIL_RANDOM_H_ #define DATASET_UTIL_RANDOM_H_ + +#if defined(_WIN32) || defined(_WIN64) +#include +#endif +#include +#include +#include +#include + +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" + namespace mindspore { namespace dataset { -uint32_t GetSeed(); +inline std::mt19937 GetRandomDevice() { +#if defined(_WIN32) || defined(_WIN64) + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; +#else + std::mt19937 random_device{std::random_device("/dev/urandom")()}; +#endif + return random_device; +} + +inline uint32_t GetNewSeed() { + std::mt19937 random_device = GetRandomDevice(); + std::uniform_int_distribution distribution(0, std::numeric_limits::max()); + return distribution(random_device); +} + +inline uint32_t GetSeed() { + uint32_t seed = GlobalContext::config_manager()->seed(); + if (seed == std::mt19937::default_seed) { + seed = GetNewSeed(); + } + return seed; +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc index a2b3f734c2..6516deea41 100644 --- a/mindspore/ccsrc/dataset/util/services.cc +++ b/mindspore/ccsrc/dataset/util/services.cc @@ -22,8 +22,8 @@ #include #endif #include -#include #include "dataset/util/circular_pool.h" +#include "dataset/util/random.h" #include "dataset/util/task_manager.h" #define SLOT_TASK_MGR 0 @@ -50,14 +50,8 @@ int Services::GetLWP() { return syscall(SYS_gettid); } std::string Services::GetUniqueID() { const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; -#if defined(_WIN32) || defined(_WIN64) - unsigned int number; - rand_s(&number); - std::mt19937 gen{static_cast(number)}; -#else - std::mt19937 gen{std::random_device{"/dev/urandom"}()}; -#endif - std::uniform_int_distribution<> dist(0, kStr.size() - 1); + std::mt19937 gen = GetRandomDevice(); + std::uniform_int_distribution dist(0, kStr.size() - 1); char buffer[UNIQUEID_LEN]; for (int i = 0; i < UNIQUEID_LEN; i++) { buffer[i] = kStr[dist(gen)];