!4082 Added notation for graph-level seed access interfaces

Merge pull request !4082 from peixu_ren/master
pull/4082/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 68fc7c2c1f

@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import check_int_positive
from ..._checkparam import Rel
# set graph-level RNG seed
@ -29,11 +30,36 @@ _GRAPH_SEED = 0
@constexpr
def set_seed(seed):
"""
Set the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.
Args:
seed(Int): the graph-level seed value that to be set.
Examples:
>>> C.set_seed(10)
"""
check_int_positive(seed)
global _GRAPH_SEED
_GRAPH_SEED = seed
@constexpr
def get_seed():
"""
Get the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.
Returns:
Interger. The current graph-level seed.
Examples:
>>> C.get_seed(10)
"""
return _GRAPH_SEED
@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0):
>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5)
"""
mean_dtype = F.dtype(mean)

Loading…
Cancel
Save