Check seed of non negative

pull/6697/head
peixu_ren 5 years ago
parent 00d254415a
commit 4aca47801c

@ -97,14 +97,13 @@ def _get_op_seed(op_seed, kernel_name):
seed (int): The op-seed to be updated.
kernel_name (string): The random op kernel.
"""
if ((kernel_name, op_seed) not in _KERNEL_SEED) or (_KERNEL_SEED[(kernel_name, op_seed)] == -1):
if (kernel_name, op_seed) not in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed
_KERNEL_SEED[(kernel_name, op_seed)] = 0
return _KERNEL_SEED[(kernel_name, op_seed)]
def _reset_op_seed():
"""
Reset op seeds in the kernel's dictionary.
"""
for key in _KERNEL_SEED:
_KERNEL_SEED[key] = -1
for (kernel_name, op_seed) in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed

@ -54,6 +54,7 @@ def get_seed(op_seed, kernel_name):
if op_seed is None:
temp_seed = _get_op_seed(0, kernel_name)
else:
const_utils.check_int_non_negative("seed", op_seed, kernel_name)
temp_seed = _get_op_seed(op_seed, kernel_name)
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
_update_seeds(op_seed, kernel_name)
@ -88,7 +89,6 @@ def normal(shape, mean, stddev, seed=None):
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
seed1, seed2 = get_seed(seed, "normal")
const_utils.check_int_non_negative("seed", seed2, "normal")
stdnormal = P.StandardNormal(seed1, seed2)
random_normal = stdnormal(shape)
value = random_normal * stddev + mean
@ -126,7 +126,6 @@ def laplace(shape, mean, lambda_param, seed=None):
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "laplace")
const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace")
seed1, seed2 = get_seed(seed, "laplace")
const_utils.check_int_non_negative("seed", seed2, "laplace")
stdlaplace = P.StandardLaplace(seed1, seed2)
rnd = stdlaplace(shape)
value = rnd * lambda_param + mean
@ -177,7 +176,6 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
seed1, seed2 = get_seed(seed, "uniform")
const_utils.check_int_non_negative("seed", seed2, "uniform")
if const_utils.is_same_type(dtype, mstype.int32):
random_uniform = P.UniformInt(seed1, seed2)
value = random_uniform(shape, minval, maxval)
@ -210,7 +208,6 @@ def gamma(shape, alpha, beta, seed=None):
>>> output = C.gamma(shape, alpha, beta, seed=5)
"""
seed1, seed2 = get_seed(seed, "gamma")
const_utils.check_int_non_negative("seed", seed2, "gamma")
random_gamma = P.Gamma(seed1, seed2)
value = random_gamma(shape, alpha, beta)
return value
@ -235,7 +232,6 @@ def poisson(shape, mean, seed=None):
>>> output = C.poisson(shape, mean, seed=5)
"""
seed1, seed2 = get_seed(seed, "poisson")
const_utils.check_int_non_negative("seed", seed2, "poisson")
random_poisson = P.Poisson(seed1, seed2)
value = random_poisson(shape, mean)
return value

Loading…
Cancel
Save