From 3820533ad176ef82f82e6f5f22eec618d5ad15d0 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Fri, 7 Aug 2020 14:41:53 -0400 Subject: [PATCH] Refactor Gamma and Poisson ops --- mindspore/ops/composite/__init__.py | 6 +- mindspore/ops/composite/random_ops.py | 64 +++++++++++++++++-- .../test_compoite_random_ops/test_gamma.py | 56 ++++++++++++++++ .../test_normal.py | 3 - .../test_compoite_random_ops/test_poisson.py | 54 ++++++++++++++++ .../test_uniform.py | 1 - tests/ut/python/ops/test_ops.py | 8 +-- 7 files changed, 177 insertions(+), 15 deletions(-) create mode 100644 tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py rename tests/st/ops/ascend/{test_aicpu_ops => test_compoite_random_ops}/test_normal.py (99%) create mode 100644 tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py rename tests/st/ops/ascend/{test_aicpu_ops => test_compoite_random_ops}/test_uniform.py (99%) diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index 6656dafdb4..b06a2a397e 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .clip_ops import clip_by_value from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like -from .random_ops import set_seed, normal, multinomial, uniform +from .random_ops import set_seed, normal, uniform, gamma, poisson, multinomial __all__ = [ @@ -49,7 +49,9 @@ __all__ = [ 'ones_like', 'zip_operation', 'set_seed', - 'uniform', 'normal', + 'uniform', + 'gamma', + 'poisson', 'multinomial', 'clip_by_value',] diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index 70313e72e0..b052e7a795 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -66,7 +66,6 @@ def get_seed(): def normal(shape, mean, stddev, seed=0): """ Generates random numbers according to the Normal (or Gaussian) random number distribution. - It is defined as: Args: shape (tuple): The shape of random tensor to be generated. @@ -84,7 +83,6 @@ def normal(shape, mean, stddev, seed=0): >>> shape = (4, 16) >>> mean = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32) - >>> C.set_seed(10) >>> output = C.normal(shape, mean, stddev, seed=5) """ mean_dtype = F.dtype(mean) @@ -148,8 +146,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0): def uniform(shape, a, b, seed=0, dtype=mstype.float32): """ - Generates random numbers according to the Uniform (or Gaussian) random number distribution. - It is defined as: + Generates random numbers according to the Uniform random number distribution. Args: shape (tuple): The shape of random tensor to be generated. @@ -170,7 +167,6 @@ def uniform(shape, a, b, seed=0, dtype=mstype.float32): >>> shape = (4, 16) >>> a = Tensor(1.0, mstype.float32) >>> b = Tensor(1.0, mstype.float32) - >>> C.set_seed(10) >>> output = C.uniform(shape, a, b, seed=5) """ a_dtype = F.dtype(a) @@ -187,3 +183,61 @@ def uniform(shape, a, b, seed=0, dtype=mstype.float32): rnd = uniform_real(shape) value = rnd * (b - a) + a return value + +def gamma(shape, alpha, beta, seed=0): + """ + Generates random numbers according to the Gamma random number distribution. + + Args: + shape (tuple): The shape of random tensor to be generated. + alpha (Tensor): The alpha α distribution parameter. With float32 data type. + beta (Tensor): The beta β distribution parameter. With float32 data type. + seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Returns: + Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta. + The dtype is float32. + + Examples: + >>> shape = (4, 16) + >>> alpha = Tensor(1.0, mstype.float32) + >>> beta = Tensor(1.0, mstype.float32) + >>> output = C.gamma(shape, alpha, beta, seed=5) + """ + alpha_dtype = F.dtype(alpha) + beta_dtype = F.dtype(beta) + const_utils.check_tensors_dtype_same(alpha_dtype, mstype.float32, "gamma") + const_utils.check_tensors_dtype_same(beta_dtype, mstype.float32, "gamma") + seed1 = get_seed() + seed2 = seed + gamma = P.Gamma(seed1, seed2) + value = gamma(shape, alpha, beta) + return value + +def poisson(shape, mean, seed=0): + """ + Generates random numbers according to the Poisson random number distribution. + + Args: + shape (tuple): The shape of random tensor to be generated. + mean (Tensor): The mean μ distribution parameter. With float32 data type. + seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Returns: + Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean. + The dtype is float32. + + Examples: + >>> shape = (4, 16) + >>> mean = Tensor(1.0, mstype.float32) + >>> output = C.poisson(shape, mean, seed=5) + """ + mean_dtype = F.dtype(mean) + const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "poisson") + seed1 = get_seed() + seed2 = seed + poisson = P.Poisson(seed1, seed2) + value = poisson(shape, mean) + return value diff --git a/tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py b/tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py new file mode 100644 index 0000000000..e762aedc21 --- /dev/null +++ b/tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, alpha, beta): + C.set_seed(20) + return C.gamma(self.shape, alpha, beta, self.seed) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + alpha = 1.0 + beta = 1.0 + net = Net(shape, seed) + talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) + output = net(talpha, tbeta) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + beta = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) + output = net(talpha, tbeta) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py b/tests/st/ops/ascend/test_compoite_random_ops/test_normal.py similarity index 99% rename from tests/st/ops/ascend/test_aicpu_ops/test_normal.py rename to tests/st/ops/ascend/test_compoite_random_ops/test_normal.py index 01ecbf5ec4..6c6e07b584 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py +++ b/tests/st/ops/ascend/test_compoite_random_ops/test_normal.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - import numpy as np -import pytest import mindspore.context as context import mindspore.nn as nn @@ -56,4 +54,3 @@ def test_net_ND(): tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) output = net(tmean, tstddev) assert output.shape == (3, 2, 2) - diff --git a/tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py b/tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py new file mode 100644 index 0000000000..caa0a1f642 --- /dev/null +++ b/tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, mean): + C.set_seed(20) + return C.poisson(self.shape, mean, self.seed) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + mean = 1.0 + net = Net(shape, seed) + tmean = Tensor(mean, mstype.float32) + output = net(tmean) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + net = Net(shape, seed) + tmean = Tensor(mean, mstype.float32) + output = net(tmean) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform.py b/tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py similarity index 99% rename from tests/st/ops/ascend/test_aicpu_ops/test_uniform.py rename to tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py index cef50bdbc7..44a7798250 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_uniform.py +++ b/tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - import numpy as np import mindspore.context as context diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 8b2e7ab432..2eb3584c33 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -592,22 +592,22 @@ class LaplaceNet(nn.Cell): class GammaNet(nn.Cell): def __init__(self, shape=None, seed=0): super(GammaNet, self).__init__() - self.gamma = P.Gamma(seed=seed) self.shape = shape + self.seed = seed def construct(self, alpha, beta): - out = self.gamma(self.shape, alpha, beta) + out = C.gamma(self.shape, alpha, beta, self.seed) return out class PoissonNet(nn.Cell): def __init__(self, shape=None, seed=0): super(PoissonNet, self).__init__() - self.poisson = P.Poisson(seed=seed) self.shape = shape + self.seed = seed def construct(self, mean): - out = self.poisson(self.shape, mean) + out = C.poisson(self.shape, mean, self.seed) return out