From 4aca47801cd080ffbcd59dddf292bc7c92bf2402 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Mon, 21 Sep 2020 22:06:35 -0400 Subject: [PATCH] Check seed of non negative --- mindspore/common/seed.py | 7 +++---- mindspore/ops/composite/random_ops.py | 6 +----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py index 4117438e5e..44cfcdecda 100644 --- a/mindspore/common/seed.py +++ b/mindspore/common/seed.py @@ -97,14 +97,13 @@ def _get_op_seed(op_seed, kernel_name): seed (int): The op-seed to be updated. kernel_name (string): The random op kernel. """ - if ((kernel_name, op_seed) not in _KERNEL_SEED) or (_KERNEL_SEED[(kernel_name, op_seed)] == -1): + if (kernel_name, op_seed) not in _KERNEL_SEED: _KERNEL_SEED[(kernel_name, op_seed)] = op_seed - _KERNEL_SEED[(kernel_name, op_seed)] = 0 return _KERNEL_SEED[(kernel_name, op_seed)] def _reset_op_seed(): """ Reset op seeds in the kernel's dictionary. """ - for key in _KERNEL_SEED: - _KERNEL_SEED[key] = -1 + for (kernel_name, op_seed) in _KERNEL_SEED: + _KERNEL_SEED[(kernel_name, op_seed)] = op_seed diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index a85a2d165a..4a41cb4507 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -54,6 +54,7 @@ def get_seed(op_seed, kernel_name): if op_seed is None: temp_seed = _get_op_seed(0, kernel_name) else: + const_utils.check_int_non_negative("seed", op_seed, kernel_name) temp_seed = _get_op_seed(op_seed, kernel_name) seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) _update_seeds(op_seed, kernel_name) @@ -88,7 +89,6 @@ def normal(shape, mean, stddev, seed=None): const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal") const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal") seed1, seed2 = get_seed(seed, "normal") - const_utils.check_int_non_negative("seed", seed2, "normal") stdnormal = P.StandardNormal(seed1, seed2) random_normal = stdnormal(shape) value = random_normal * stddev + mean @@ -126,7 +126,6 @@ def laplace(shape, mean, lambda_param, seed=None): const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "laplace") const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace") seed1, seed2 = get_seed(seed, "laplace") - const_utils.check_int_non_negative("seed", seed2, "laplace") stdlaplace = P.StandardLaplace(seed1, seed2) rnd = stdlaplace(shape) value = rnd * lambda_param + mean @@ -177,7 +176,6 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32): const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform") seed1, seed2 = get_seed(seed, "uniform") - const_utils.check_int_non_negative("seed", seed2, "uniform") if const_utils.is_same_type(dtype, mstype.int32): random_uniform = P.UniformInt(seed1, seed2) value = random_uniform(shape, minval, maxval) @@ -210,7 +208,6 @@ def gamma(shape, alpha, beta, seed=None): >>> output = C.gamma(shape, alpha, beta, seed=5) """ seed1, seed2 = get_seed(seed, "gamma") - const_utils.check_int_non_negative("seed", seed2, "gamma") random_gamma = P.Gamma(seed1, seed2) value = random_gamma(shape, alpha, beta) return value @@ -235,7 +232,6 @@ def poisson(shape, mean, seed=None): >>> output = C.poisson(shape, mean, seed=5) """ seed1, seed2 = get_seed(seed, "poisson") - const_utils.check_int_non_negative("seed", seed2, "poisson") random_poisson = P.Poisson(seed1, seed2) value = random_poisson(shape, mean) return value