|
|
|
@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import check_int_positive
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
|
|
|
|
|
# set graph-level RNG seed
|
|
|
|
@ -29,11 +30,36 @@ _GRAPH_SEED = 0
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def set_seed(seed):
|
|
|
|
|
"""
|
|
|
|
|
Set the graph-level seed.
|
|
|
|
|
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
|
|
|
|
|
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
|
|
|
|
|
random seed.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
seed(Int): the graph-level seed value that to be set.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> C.set_seed(10)
|
|
|
|
|
"""
|
|
|
|
|
check_int_positive(seed)
|
|
|
|
|
global _GRAPH_SEED
|
|
|
|
|
_GRAPH_SEED = seed
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def get_seed():
|
|
|
|
|
"""
|
|
|
|
|
Get the graph-level seed.
|
|
|
|
|
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
|
|
|
|
|
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
|
|
|
|
|
random seed.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Interger. The current graph-level seed.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> C.get_seed(10)
|
|
|
|
|
"""
|
|
|
|
|
return _GRAPH_SEED
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0):
|
|
|
|
|
>>> shape = (4, 16)
|
|
|
|
|
>>> mean = Tensor(1.0, mstype.float32)
|
|
|
|
|
>>> stddev = Tensor(1.0, mstype.float32)
|
|
|
|
|
>>> C.set_seed(10)
|
|
|
|
|
>>> output = C.normal(shape, mean, stddev, seed=5)
|
|
|
|
|
"""
|
|
|
|
|
mean_dtype = F.dtype(mean)
|
|
|
|
|