modify_normal_seed

pull/8271/head
lilei 4 years ago
parent b3855530e3
commit 4d48049b27

@ -32,6 +32,15 @@ class PhiloxGenerator {
counter_[3] = static_cast<uint32_t>(seed_ >> 32);
}
explicit PhiloxGenerator(uint64_t seed_, uint64_t seed2_) {
key_var_[0] = static_cast<uint32_t>(seed_);
key_var_[1] = static_cast<uint32_t>(seed_ >> 32);
counter_[0] = 0;
counter_[1] = 0;
counter_[2] = static_cast<uint32_t>(seed2_);
counter_[3] = static_cast<uint32_t>(seed2_ >> 32);
}
~PhiloxGenerator() = default;
void Jump();

@ -20,7 +20,7 @@
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed,
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,
const py::object &output_tensor) {
if (out_shape.size() == 0) {
std::cout << "output data shape is error" << std::endl;
@ -41,7 +41,8 @@ bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape,
}
int64_t batchSize = total_count / thread_num;
std::vector<std::thread> threads(thread_num);
mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed);
seed = (seed == 0 && seed2 == 0) ? clock() : seed;
mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2);
if (thread_num != 1) {
for (uint32_t i = 0; i < thread_num - 1; i++) {
float *offset_ptr = start_ptr + batchSize * i;

@ -85,7 +85,7 @@ bool FillRandoms(PhiloxGenerator generator, float *output, int64_t vet_size, int
}
return true;
}
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed,
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,
const py::object &output_tensor);
} // namespace mindspore

@ -45,16 +45,13 @@ class Initializer:
@property
def seed(self):
if self._seed is None:
seed_ = get_seed() if get_seed() is not None else 1
_, seed = _get_graph_seed(seed_, "init")
seed, seed2 = _get_graph_seed(get_seed(), "init")
else:
seed = self._seed
return seed
seed, seed2 = self._seed + 1, 0
return seed, seed2
@seed.setter
def seed(self, value):
if not isinstance(value, int):
raise TypeError("'value' must be int type.")
self._seed = value
def _initialize(self, *kwargs):
@ -367,9 +364,9 @@ class Normal(Initializer):
self.sigma = sigma
def _initialize(self, arr):
seed = self.seed
seed, seed2 = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, output_tensor)
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
output_data = output_tensor.asnumpy()
output_data *= self.sigma
_assignment(arr, output_data)

@ -18,6 +18,7 @@ import mindspore.dataset as de
from mindspore._checkparam import Validator
# constants
DEFAULT_GRAPH_SEED = 87654321
_MAXINT32 = 2**31 - 1
keyConstant = [3528531795, 2654435769, 3449720151, 3144134277]
@ -210,7 +211,9 @@ def _get_graph_seed(op_seed, kernel_name):
>>> _get_graph_seed(seed, 'normal')
"""
global_seed = get_seed()
if global_seed is None:
if global_seed == 0:
global_seed = DEFAULT_GRAPH_SEED
elif global_seed is None:
global_seed = 0
if op_seed is None:
op_seed = 0

@ -465,7 +465,7 @@ class MetaTensor(MetaTensor_):
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
np.random.seed(self._np_seed)
self.init.seed = self.seed
self.init.seed, _ = self.seed
with seed_context(self.init):
self.init(arr)

@ -39,7 +39,7 @@ class WithBNNLossCell(Cell):
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> net_with_criterion_object = WithBNNLossCell(net, loss_fn)
>>> net_with_criterion = net_with_criterion_object()
>>>

@ -46,7 +46,7 @@ class WithLossCell(Cell):
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
>>> net_with_criterion = nn.WithLossCell(net, loss_fn)
>>>
>>> batch_size = 2

Loading…
Cancel
Save