fix: random_op.py pylint error

pull/4297/head
jonyguo 5 years ago
parent 4964f7703a
commit 866e9259a4

@ -136,10 +136,10 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
n_dist = shape(inputs)[-2]
a = Tensor(0.0, mstype.float32)
b = Tensor(1.0, mstype.float32)
uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
if n_dist != 1:
uniform = reshape(uniform, (n_dist, num_sample))
vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6)
random_uniform = reshape(random_uniform, (n_dist, num_sample))
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
_, indices = P.TopK()(vals, num_sample)
return indices
return P.Multinomial(seed=seed)(inputs, num_sample)
@ -211,8 +211,8 @@ def gamma(shape, alpha, beta, seed=0):
const_utils.check_tensors_dtype_same(beta_dtype, mstype.float32, "gamma")
seed1 = get_seed()
seed2 = seed
gamma = P.Gamma(seed1, seed2)
value = gamma(shape, alpha, beta)
random_gamma = P.Gamma(seed1, seed2)
value = random_gamma(shape, alpha, beta)
return value
def poisson(shape, mean, seed=0):
@ -238,6 +238,6 @@ def poisson(shape, mean, seed=0):
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "poisson")
seed1 = get_seed()
seed2 = seed
poisson = P.Poisson(seed1, seed2)
value = poisson(shape, mean)
random_poisson = P.Poisson(seed1, seed2)
value = random_poisson(shape, mean)
return value

Loading…
Cancel
Save