diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py
index ca877852a8..1bf747c08a 100644
--- a/mindspore/nn/probability/bijector/power_transform.py
+++ b/mindspore/nn/probability/bijector/power_transform.py
@@ -17,6 +17,7 @@ from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ..distribution._utils.utils import CheckTensor
+from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
 from .bijector import Bijector
 
 class PowerTransform(Bijector):
@@ -59,24 +60,12 @@ class PowerTransform(Bijector):
         self._power = power
         self.pow = P.Pow()
         self.exp = P.Exp()
-        self.log = P.Log()
-        self.log1p = self._log1p_by_step
-        self.expm1 = self._expm1_by_step
+        self.log = log_by_step
+        self.log1p = log1p_by_step
+        self.expm1 = expm1_by_step
         self.checktensor = CheckTensor()
 
-    def _log1p_by_step(self, x):
-        """
-        Log1p ops on GPU device or when device_target == GPU.
-        """
-        return self.log(x + 1.0)
-
-    def _expm1_by_step(self, x):
-        """
-        Expm1 ops on GPU device or when device_target == GPU.
-        """
-        return self.exp(x) - 1.0
-
     @property
     def power(self):
         return self._power
 
diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py
index 44de3c68a0..276009b5fc 100644
--- a/mindspore/nn/probability/bijector/scalar_affine.py
+++ b/mindspore/nn/probability/bijector/scalar_affine.py
@@ -16,6 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import cast_to_tensor, CheckTensor
+from ..distribution._utils.custom_ops import log_by_step
 from .bijector import Bijector
 
 class ScalarAffine(Bijector):
@@ -66,7 +67,7 @@
             param=param)
 
         self.abs = P.Abs()
-        self.log = P.Log()
+        self.log = log_by_step
 
         self.checktensor = CheckTensor()
 
diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py
index 070e483707..69ea2d8d05 100644
--- a/mindspore/nn/probability/bijector/softplus.py
+++ b/mindspore/nn/probability/bijector/softplus.py
@@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import cast_to_tensor, CheckTensor
+from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
 from .bijector import Bijector
 
 class Softplus(Bijector):
@@ -60,12 +61,12 @@
 
         self.abs = P.Abs()
         self.exp = P.Exp()
-        self.expm1 = self._expm1_by_step
+        self.log = log_by_step
+        self.expm1 = expm1_by_step
         self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
         self.log_sigmoid = LogSigmoid()
-        self.log = P.Log()
         self.logicalor = P.LogicalOr()
         self.select = P.Select()
         self.shape = P.Shape()
@@ -76,12 +77,6 @@
         self.checktensor = CheckTensor()
         self.threshold = np.log(np.finfo(np.float32).eps) + 1
 
-    def _expm1_by_step(self, x):
-        """
-        Expm1 ops under GPU context.
-        """
-        return self.exp(x) - 1.0
-
     def _softplus(self, x):
         too_small = self.less(x, self.threshold)
         too_large = self.greater(x, -self.threshold)
diff --git a/mindspore/nn/probability/distribution/_utils/__init__.py b/mindspore/nn/probability/distribution/_utils/__init__.py
index 68586ae4fd..07ebb623b5 100644
--- a/mindspore/nn/probability/distribution/_utils/__init__.py
+++ b/mindspore/nn/probability/distribution/_utils/__init__.py
@@ -16,6 +16,7 @@
 Distribution operation utility functions.
 """
 from .utils import *
+from .custom_ops import *
 
 __all__ = [
     'convert_to_batch',
@@ -27,4 +28,7 @@ __all__ = [
     'check_scalar_from_param',
     'check_prob',
     'check_type',
+    'log_by_step',
+    'log1p_by_step',
+    'expm1_by_step',
 ]
diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py
new file mode 100644
index 0000000000..430a311613
--- /dev/null
+++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utility functions to help distribution classes."""
+import numpy as np
+from mindspore.ops import operations as P
+
+def log_by_step(input_x):
+    """
+    Log op on Ascend is calculated as log(abs(x)).
+    Fix this by returning nan for non-positive values.
+    """
+    select = P.Select()
+    log = P.Log()
+    lessequal = P.LessEqual()
+    fill = P.Fill()
+    dtype = P.DType()
+    shape = P.Shape()
+
+    nonpos_x = lessequal(input_x, 0.0)
+    log_x = log(input_x)
+    nan = fill(dtype(input_x), shape(input_x), np.nan)
+    result = select(nonpos_x, nan, log_x)
+    return result
+
+def log1p_by_step(x):
+    """
+    Log1p ops on a GPU device, i.e. when device_target == GPU.
+    """
+    return log_by_step(x + 1.0)
+
+def expm1_by_step(input_x):
+    """
+    Expm1 ops under GPU context.
+    """
+    exp = P.Exp()
+    return exp(input_x) - 1.0
diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py
index e0170ffe4e..512a935ba8 100644
--- a/mindspore/nn/probability/distribution/bernoulli.py
+++ b/mindspore/nn/probability/distribution/bernoulli.py
@@ -19,6 +19,7 @@ from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
 from ._utils.utils import CheckTensor, CheckTuple
+from ._utils.custom_ops import log_by_step
 
 class Bernoulli(Distribution):
     """
@@ -116,7 +117,7 @@
         self.exp = P.Exp()
         self.floor = P.Floor()
         self.fill = P.Fill()
-        self.log = P.Log()
+        self.log = log_by_step
         self.less = P.Less()
         self.shape = P.Shape()
         self.select = P.Select()
diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py
index a8f38df16f..5a6ada38d5 100644
--- a/mindspore/nn/probability/distribution/exponential.py
+++ b/mindspore/nn/probability/distribution/exponential.py
@@ -21,6 +21,7 @@ from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.utils import CheckTensor, CheckTuple
+from ._utils.custom_ops import log_by_step
 
 class Exponential(Distribution):
     """
@@ -119,7 +120,7 @@
         self.exp = P.Exp()
         self.fill = P.Fill()
         self.less = P.Less()
-        self.log = P.Log()
+        self.log = log_by_step
         self.select = P.Select()
         self.shape = P.Shape()
         self.sqrt = P.Sqrt()
diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py
index 1f16ef0240..9b1d866966 100644
--- a/mindspore/nn/probability/distribution/geometric.py
+++ b/mindspore/nn/probability/distribution/geometric.py
@@ -21,6 +21,7 @@ from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.utils import CheckTensor, CheckTuple
+from ._utils.custom_ops import log_by_step
 
 class Geometric(Distribution):
     """
@@ -122,7 +123,7 @@
         self.floor = P.Floor()
         self.issubclass = P.IsSubClass()
         self.less = P.Less()
-        self.log = P.Log()
+        self.log = log_by_step
         self.pow = P.Pow()
         self.select = P.Select()
         self.shape = P.Shape()
diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py
index ec72d5ea78..fc9a35908d 100644
--- a/mindspore/nn/probability/distribution/normal.py
+++ b/mindspore/nn/probability/distribution/normal.py
@@ -21,6 +21,7 @@ from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.utils import CheckTensor, CheckTuple
+from ._utils.custom_ops import log_by_step, expm1_by_step
 
 class Normal(Distribution):
     """
@@ -119,9 +120,9 @@
         self.const = P.ScalarToArray()
         self.erf = P.Erf()
         self.exp = P.Exp()
-        self.expm1 = self._expm1_by_step
+        self.expm1 = expm1_by_step
         self.fill = P.Fill()
-        self.log = P.Log()
+        self.log = log_by_step
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
@@ -137,12 +138,6 @@
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info
 
-    def _expm1_by_step(self, x):
-        """
-        Expm1 ops under GPU context.
-        """
-        return self.exp(x) - 1.0
-
     def _check_param(self, mean, sd):
         """
         Check availablity of distribution specific args mean and sd.
diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py
index 259f105d4e..8a34c67b05 100644
--- a/mindspore/nn/probability/distribution/transformed_distribution.py
+++ b/mindspore/nn/probability/distribution/transformed_distribution.py
@@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from .distribution import Distribution
 from ._utils.utils import check_type, raise_not_impl_error
+from ._utils.custom_ops import log_by_step
 
 class TransformedDistribution(Distribution):
     """
@@ -56,7 +57,7 @@
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
         self.exp = P.Exp()
-        self.log = P.Log()
+        self.log = log_by_step
 
     @property
     def bijector(self):
diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py
index a37162a507..0d1b96c9e6 100644
--- a/mindspore/nn/probability/distribution/uniform.py
+++ b/mindspore/nn/probability/distribution/uniform.py
@@ -20,6 +20,7 @@ from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.utils import CheckTensor, CheckTuple
+from ._utils.custom_ops import log_by_step
 
 class Uniform(Distribution):
     """
@@ -121,7 +122,7 @@
         self.fill = P.Fill()
         self.less = P.Less()
         self.lessequal = P.LessEqual()
-        self.log = P.Log()
+        self.log = log_by_step
         self.logicaland = P.LogicalAnd()
         self.select = P.Select()
         self.shape = P.Shape()
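
A note on the pattern above (reviewer sketch, not part of the patch): log_by_step exists because P.Log on Ascend is computed as log(abs(x)), so a negative input silently yields a finite value instead of nan. The new helper masks every non-positive entry to nan with LessEqual, Fill, and Select. Below is a minimal NumPy analogue of that masking; the function name and sample values are hypothetical, chosen only to illustrate the behavior.

```python
import numpy as np

def log_by_step_reference(x):
    # Mirrors custom_ops.log_by_step: take log(x) elementwise, then
    # overwrite entries where x <= 0 with nan, so an Ascend-style
    # log(abs(x)) result cannot leak through for invalid inputs.
    nonpos = x <= 0.0                              # lessequal(input_x, 0.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_x = np.log(x)                          # log(input_x)
    nan = np.full(x.shape, np.nan, dtype=x.dtype)  # fill(dtype(input_x), shape(input_x), np.nan)
    return np.where(nonpos, nan, log_x)            # select(nonpos_x, nan, log_x)

print(log_by_step_reference(np.array([-2.0, 0.0, 1.0, 4.0], dtype=np.float32)))
# -> [nan, nan, 0.0, ~1.3863]
```

One caveat worth flagging: log1p_by_step and expm1_by_step fall back to log(x + 1.0) and exp(x) - 1.0, which lose precision when |x| is close to zero; avoiding that loss is the usual reason native Log1p/Expm1 kernels exist, so these fallbacks trade some accuracy for portability across device targets.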