diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index 8484fdeddc..604fa54494 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -218,6 +218,11 @@ def raise_not_impl_error(name):
     raise ValueError(
         f"{name} function should be implemented for non-linear transformation")
 
+@constexpr
+def raise_not_implemented_util(func_name, obj, *args, **kwargs):
+    raise NotImplementedError(
+        f"{func_name} is not implemented for {obj} distribution.")
+
 
 @constexpr
 def check_distribution_name(name, expected_name):
diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py
index ca6cabea43..0547810c4e 100644
--- a/mindspore/nn/probability/distribution/distribution.py
+++ b/mindspore/nn/probability/distribution/distribution.py
@@ -19,7 +19,8 @@ from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
+from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
+    raise_not_implemented_util
 from ._utils.utils import CheckTuple, CheckTensor
 from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
 
@@ -245,6 +246,8 @@ class Distribution(Cell):
             self._call_prob = self._prob
         elif hasattr(self, '_log_prob'):
             self._call_prob = self._calc_prob_from_log_prob
+        else:
+            self._call_prob = self._raise_not_implemented_error('prob')
 
     def _set_sd(self):
         """
@@ -254,6 +257,8 @@ class Distribution(Cell):
             self._call_sd = self._sd
         elif hasattr(self, '_var'):
             self._call_sd = self._calc_sd_from_var
+        else:
+            self._call_sd = self._raise_not_implemented_error('sd')
 
     def _set_var(self):
         """
@@ -263,6 +268,8 @@ class Distribution(Cell):
             self._call_var = self._var
         elif hasattr(self, '_sd'):
             self._call_var = self._calc_var_from_sd
+        else:
+            self._call_var = self._raise_not_implemented_error('var')
 
     def _set_log_prob(self):
         """
@@ -272,6 +279,8 @@ class Distribution(Cell):
             self._call_log_prob = self._log_prob
         elif hasattr(self, '_prob'):
             self._call_log_prob = self._calc_log_prob_from_prob
+        else:
+            self._call_log_prob = self._raise_not_implemented_error('log_prob')
 
     def _set_cdf(self):
         """
@@ -286,13 +295,18 @@ class Distribution(Cell):
             self._call_cdf = self._calc_cdf_from_survival
         elif hasattr(self, '_log_survival'):
             self._call_cdf = self._calc_cdf_from_log_survival
+        else:
+            self._call_cdf = self._raise_not_implemented_error('cdf')
 
     def _set_survival(self):
         """
         Set survival function based on the availability of _survival function and `_log_survival`
         and `_call_cdf`.
         """
-        if hasattr(self, '_survival_function'):
+        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
+                hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
+            self._call_survival = self._raise_not_implemented_error('survival_function')
+        elif hasattr(self, '_survival_function'):
             self._call_survival = self._survival_function
         elif hasattr(self, '_log_survival'):
             self._call_survival = self._calc_survival_from_log_survival
@@ -303,7 +317,10 @@ class Distribution(Cell):
         """
         Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
""" - if hasattr(self, '_log_cdf'): + if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \ + hasattr(self, '_survival_function') or hasattr(self, '_log_survival')): + self._call_log_cdf = self._raise_not_implemented_error('log_cdf') + elif hasattr(self, '_log_cdf'): self._call_log_cdf = self._log_cdf elif hasattr(self, '_call_cdf'): self._call_log_cdf = self._calc_log_cdf_from_call_cdf @@ -312,7 +329,10 @@ class Distribution(Cell): """ Set log survival based on the availability of `_log_survival` and `_call_survival`. """ - if hasattr(self, '_log_survival'): + if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \ + hasattr(self, '_log_cdf') or hasattr(self, '_cdf')): + self._call_log_survival = self._raise_not_implemented_error('log_cdf') + elif hasattr(self, '_log_survival'): self._call_log_survival = self._log_survival elif hasattr(self, '_call_survival'): self._call_log_survival = self._calc_log_survival_from_call_survival @@ -323,6 +343,14 @@ class Distribution(Cell): """ if hasattr(self, '_cross_entropy'): self._call_cross_entropy = self._cross_entropy + else: + self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy') + + def _raise_not_implemented_error(self, func_name): + name = self.name + def raise_error(*args, **kwargs): + return raise_not_implemented_util(func_name, name, *args, **kwargs) + return raise_error def log_prob(self, value, *args, **kwargs): """ @@ -495,6 +523,9 @@ class Distribution(Cell): """ return self.log_base(self._call_survival(value, *args, **kwargs)) + def _kl_loss(self, *args, **kwargs): + return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs) + def kl_loss(self, dist, *args, **kwargs): """ Evaluate the KL divergence, i.e. KL(a||b). @@ -510,6 +541,9 @@ class Distribution(Cell): """ return self._kl_loss(dist, *args, **kwargs) + def _mean(self, *args, **kwargs): + return raise_not_implemented_util('mean', self.name, *args, **kwargs) + def mean(self, *args, **kwargs): """ Evaluate the mean. @@ -524,6 +558,9 @@ class Distribution(Cell): """ return self._mean(*args, **kwargs) + def _mode(self, *args, **kwargs): + return raise_not_implemented_util('mode', self.name, *args, **kwargs) + def mode(self, *args, **kwargs): """ Evaluate the mode. @@ -584,6 +621,9 @@ class Distribution(Cell): """ return self.sq_base(self._sd(*args, **kwargs)) + def _entropy(self, *args, **kwargs): + return raise_not_implemented_util('entropy', self.name, *args, **kwargs) + def entropy(self, *args, **kwargs): """ Evaluate the entropy. @@ -622,6 +662,9 @@ class Distribution(Cell): """ return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs) + def _sample(self, *args, **kwargs): + return raise_not_implemented_util('sample', self.name, *args, **kwargs) + def sample(self, *args, **kwargs): """ Sampling function. 
@@ -680,4 +723,4 @@ class Distribution(Cell):
             return self._call_cross_entropy(*args, **kwargs)
         if name == 'sample':
             return self._sample(*args, **kwargs)
-        return None
+        return raise_not_implemented_util(name, self.name, *args, **kwargs)
diff --git a/tests/ut/python/nn/probability/distribution/test_distribution.py b/tests/ut/python/nn/probability/distribution/test_distribution.py
new file mode 100644
index 0000000000..bea14a4d65
--- /dev/null
+++ b/tests/ut/python/nn/probability/distribution/test_distribution.py
@@ -0,0 +1,102 @@
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test nn.probability.distribution.
+"""
+import pytest
+
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import dtype as mstype
+from mindspore import Tensor
+from mindspore import context
+
+func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
+                  'survival_function', 'log_survival',
+                  'sd', 'var', 'mode', 'mean',
+                  'entropy', 'kl_loss', 'cross_entropy',
+                  'sample']
+
+class MyExponential(msd.Distribution):
+    """
+    Test distribution class: no function is implemented.
+    """
+    def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
+        param = dict(locals())
+        param['param_dict'] = {'rate': rate}
+        super(MyExponential, self).__init__(seed, dtype, name, param)
+
+class Net(nn.Cell):
+    """
+    Test Net: function called through construct.
+    """
+    def __init__(self, func_name):
+        super(Net, self).__init__()
+        self.dist = MyExponential()
+        self.name = func_name
+
+    def construct(self, *args, **kwargs):
+        return self.dist(self.name, *args, **kwargs)
+
+
+def test_raise_not_implemented_error_construct():
+    """
+    test raise not implemented error in pynative mode.
+    """
+    value = Tensor([0.2], dtype=mstype.float32)
+    for func_name in func_name_list:
+        with pytest.raises(NotImplementedError):
+            net = Net(func_name)
+            net(value)
+
+def test_raise_not_implemented_error_construct_graph_mode():
+    """
+    test raise not implemented error in graph mode.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    value = Tensor([0.2], dtype=mstype.float32)
+    for func_name in func_name_list:
+        with pytest.raises(NotImplementedError):
+            net = Net(func_name)
+            net(value)
+
+class Net1(nn.Cell):
+    """
+    Test Net: function called directly.
+    """
+    def __init__(self, func_name):
+        super(Net1, self).__init__()
+        self.dist = MyExponential()
+        self.func = getattr(self.dist, func_name)
+
+    def construct(self, *args, **kwargs):
+        return self.func(*args, **kwargs)
+
+def test_raise_not_implemented_error():
+    """
+    test raise not implemented error in pynative mode.
+    """
+    value = Tensor([0.2], dtype=mstype.float32)
+    for func_name in func_name_list:
+        with pytest.raises(NotImplementedError):
+            net = Net1(func_name)
+            net(value)
+
+def test_raise_not_implemented_error_graph_mode():
+    """
+    test raise not implemented error in graph mode.
+ """ + context.set_context(mode=context.GRAPH_MODE) + value = Tensor([0.2], dtype=mstype.float32) + for func_name in func_name_list: + with pytest.raises(NotImplementedError): + net = Net1(func_name) + net(value)