@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-
 """Operators for random."""
 
-from ..._checkparam import Validator as validator
-from ..._checkparam import Rel
+from ..._checkparam import Validator, Rel
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from .._utils import get_broadcast_shape
@@ -46,16 +44,16 @@ class StandardNormal(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize StandardNormal"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
 
     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -91,16 +89,16 @@ class StandardLaplace(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize StandardLaplace"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_value_type('seed2', seed2, [int], self.name)
+        Validator.check_value_type('seed', seed, [int], self.name)
+        Validator.check_value_type('seed2', seed2, [int], self.name)
 
     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -143,18 +141,18 @@ class Gamma(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize Gamma"""
         self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
 
     def __infer__(self, shape, alpha, beta):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
-        validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
+        Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
         broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
         out = {
@@ -195,17 +193,17 @@ class Poisson(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize Poisson"""
         self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
 
     def __infer__(self, shape, mean):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
         out = {
             'shape': broadcast_shape,
@@ -251,22 +249,22 @@ class UniformInt(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize UniformInt"""
         self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
 
     def __infer__(self, shape, minval, maxval):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
-        validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
+        Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
         minval_shape = minval['shape']
         maxval_shape = maxval['shape']
-        validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
-        validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
+        Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
+        Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.int32,
@@ -298,16 +296,16 @@ class UniformReal(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize UniformReal"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
 
     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -348,18 +346,18 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, count=256, seed=0, seed2=0):
         """Initialize RandomChoiceWithMask"""
-        validator.check_value_type("count", count, [int], self.name)
-        validator.check_integer("count", count, 0, Rel.GT, self.name)
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_value_type('seed2', seed2, [int], self.name)
+        Validator.check_value_type("count", count, [int], self.name)
+        Validator.check_positive_int(count, "count", self.name)
+        Validator.check_value_type('seed', seed, [int], self.name)
+        Validator.check_value_type('seed2', seed2, [int], self.name)
 
     def infer_shape(self, x_shape):
-        validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
+        Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
+        Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
         return ([self.count, len(x_shape)], [self.count])
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
+        Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
         return (mstype.int32, mstype.bool_)
 
 
@@ -399,19 +397,19 @@ class RandomCategorical(PrimitiveWithInfer):
         self.dtype = dtype
 
         valid_values = (mstype.int32, mstype.int16, mstype.int64)
-        validator.check_type_name("dtype", dtype, valid_values, self.name)
+        Validator.check_type_name("dtype", dtype, valid_values, self.name)
         self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
                                 outputs=['output'])
 
     def __infer__(self, logits, num_samples, seed):
         logits_dtype = logits['dtype']
         valid_types = (mstype.float32, mstype.float16, mstype.float64)
-        validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
+        Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
         num_samples_v = num_samples['value']
         seed_v = seed['value']
-        validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
-        validator.check_value_type('seed', seed_v, (int,), self.name)
-        validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
+        Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
+        Validator.check_value_type('seed', seed_v, (int,), self.name)
+        Validator.check_positive_int(num_samples_v, "num_samples", self.name)
         x_shape = list(logits['shape'])
         if len(x_shape) != 2:
             raise ValueError("RandomCategorical shape should be 2-dimension.")
@@ -450,20 +448,20 @@ class Multinomial(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, seed=0):
        """init"""
-        validator.check_value_type("seed", seed, [int], self.name)
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_value_type("seed", seed, [int], self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
 
     def __infer__(self, inputs, num_samples):
         input_shape = inputs["shape"]
         if len(input_shape) != 1 and len(input_shape) != 2:
             raise ValueError("input dim must be 1 or 2")
-        validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
+        Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
         num_samples_value = num_samples["value"]
         if num_samples_value is None:
             raise ValueError(f"For {self.name}, shape nust be const")
-        validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
-        validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
+        Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
+        Validator.check_positive_int(num_samples_value, "num_samples")
         y_shape = (num_samples_value,)
         if len(input_shape) == 2:
             y_shape = (input_shape[0], num_samples_value)
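
As a quick sanity check of the migrated validator calls, a minimal usage sketch follows (illustrative only, not part of the patch above; it assumes the standard MindSpore op-call pattern shown in the class docstrings):

    from mindspore.ops import operations as P

    # Constructing the op exercises the migrated seed checks in __init__
    # (Validator.check_integer(..., Rel.GE, ...)).
    stdnormal = P.StandardNormal(seed=2, seed2=8)

    # Calling it with a constant shape tuple exercises the migrated
    # per-element check (Validator.check_positive_int) in __infer__ and
    # yields a float32 tensor of shape (4, 16).
    output = stdnormal((4, 16))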