diff --git a/mindspore/ccsrc/pybind_api/random_normal/philox_generator.h b/mindspore/ccsrc/pybind_api/random_normal/philox_generator.h index b8045dffcf..c3862e15a6 100644 --- a/mindspore/ccsrc/pybind_api/random_normal/philox_generator.h +++ b/mindspore/ccsrc/pybind_api/random_normal/philox_generator.h @@ -32,6 +32,15 @@ class PhiloxGenerator { counter_[3] = static_cast(seed_ >> 32); } + explicit PhiloxGenerator(uint64_t seed_, uint64_t seed2_) { + key_var_[0] = static_cast(seed_); + key_var_[1] = static_cast(seed_ >> 32); + counter_[0] = 0; + counter_[1] = 0; + counter_[2] = static_cast(seed2_); + counter_[3] = static_cast(seed2_ >> 32); + } + ~PhiloxGenerator() = default; void Jump(); diff --git a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc index 445407f58c..016fdae242 100644 --- a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc +++ b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc @@ -20,7 +20,7 @@ #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { -bool InitRandomNormal(float mean, float stddev, std::vector out_shape, int64_t seed, +bool InitRandomNormal(float mean, float stddev, std::vector out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor) { if (out_shape.size() == 0) { std::cout << "output data shape is error" << std::endl; @@ -41,7 +41,8 @@ bool InitRandomNormal(float mean, float stddev, std::vector out_shape, } int64_t batchSize = total_count / thread_num; std::vector threads(thread_num); - mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed); + seed = (seed == 0 && seed2 == 0) ? clock() : seed; + mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2); if (thread_num != 1) { for (uint32_t i = 0; i < thread_num - 1; i++) { float *offset_ptr = start_ptr + batchSize * i; diff --git a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h index e2625882aa..74eb7130bb 100644 --- a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h +++ b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h @@ -85,7 +85,7 @@ bool FillRandoms(PhiloxGenerator generator, float *output, int64_t vet_size, int } return true; } -bool InitRandomNormal(float mean, float stddev, std::vector out_shape, int64_t seed, +bool InitRandomNormal(float mean, float stddev, std::vector out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor); } // namespace mindspore diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 84b61c6a56..45008feebe 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -45,16 +45,13 @@ class Initializer: @property def seed(self): if self._seed is None: - seed_ = get_seed() if get_seed() is not None else 1 - _, seed = _get_graph_seed(seed_, "init") + seed, seed2 = _get_graph_seed(get_seed(), "init") else: - seed = self._seed - return seed + seed, seed2 = self._seed + 1, 0 + return seed, seed2 @seed.setter def seed(self, value): - if not isinstance(value, int): - raise TypeError("'value' must be int type.") self._seed = value def _initialize(self, *kwargs): @@ -367,9 +364,9 @@ class Normal(Initializer): self.sigma = sigma def _initialize(self, arr): - seed = self.seed + seed, seed2 = self.seed output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) - random_normal(0, self.sigma, arr.shape, seed, output_tensor) + random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor) output_data = output_tensor.asnumpy() output_data *= self.sigma _assignment(arr, output_data) diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py index 8e5f370404..fd17cd5b35 100644 --- a/mindspore/common/seed.py +++ b/mindspore/common/seed.py @@ -18,6 +18,7 @@ import mindspore.dataset as de from mindspore._checkparam import Validator # constants +DEFAULT_GRAPH_SEED = 87654321 _MAXINT32 = 2**31 - 1 keyConstant = [3528531795, 2654435769, 3449720151, 3144134277] @@ -210,7 +211,9 @@ def _get_graph_seed(op_seed, kernel_name): >>> _get_graph_seed(seed, 'normal') """ global_seed = get_seed() - if global_seed is None: + if global_seed == 0: + global_seed = DEFAULT_GRAPH_SEED + elif global_seed is None: global_seed = 0 if op_seed is None: op_seed = 0 diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 1289c5204d..86e7a3852f 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -465,7 +465,7 @@ class MetaTensor(MetaTensor_): def __exit__(self, ptype, value, trace): if self.need_set_seed: np.random.seed(self._np_seed) - self.init.seed = self.seed + self.init.seed, _ = self.seed with seed_context(self.init): self.init(arr) diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py index 3801496dc4..72772a8151 100644 --- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py @@ -39,7 +39,7 @@ class WithBNNLossCell(Cell): Examples: >>> net = Net() - >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn) >>> net_with_criterion = net_with_criterion_object() >>> diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index f283dd834d..bdf60aa953 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -46,7 +46,7 @@ class WithLossCell(Cell): Examples: >>> net = Net() - >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) >>> net_with_criterion = nn.WithLossCell(net, loss_fn) >>> >>> batch_size = 2