From e1dba1337c66ed238817690a1a060f7125cb5750 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Mon, 16 Nov 2020 20:40:07 +0800 Subject: [PATCH] Add nn.Tril function --- mindspore/nn/layer/basic.py | 76 +++++++++++++++++++++- tests/ut/python/nn/test_tril.py | 108 ++++++++++++++++++++++++++++++++ tests/ut/python/nn/test_triu.py | 108 ++++++++++++++++++++++++++++++++ 3 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 tests/ut/python/nn/test_tril.py create mode 100644 tests/ut/python/nn/test_triu.py diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 5e9283fed7..2d947d3fb2 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -35,7 +35,7 @@ from .activation import get_activation __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', - 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] + 'Tril', 'Triu', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] class Dropout(Cell): @@ -547,6 +547,80 @@ class Unfold(Cell): return result +@constexpr +def tril(x_shape, x_dtype, k): + Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril") + Validator.check_is_int(k, "k value", "tril") + mask = np.tril(np.ones(x_shape), k) + return Tensor(mask, x_dtype) + + +class Tril(Cell): + """ + Returns a tensor with elements above the kth diagonal zeroed. + + Inputs: + - **x** (Tensor) - The input tensor. + - **k** (Int) - The index of diagonal. Default: 0 + + Outputs: + Tensor, has the same type as input `x`. + + Examples: + >>> x = Tensor(np.array([[1, 2], [3, 4]])) + >>> tril = nn.Tril() + >>> result = tril(x) + >>> print(result) + [[1 0] + [3 4]] + """ + def __init__(self): + super(Tril, self).__init__() + self.dtype = P.DType() + self.mul = P.Mul() + + def construct(self, x, k=0): + assist = tril(x.shape, self.dtype(x), k) + return self.mul(x, assist) + + +@constexpr +def triu(x_shape, x_dtype, k): + Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu") + Validator.check_is_int(k, "k value", "triu") + mask = np.triu(np.ones(x_shape), k) + return Tensor(mask, x_dtype) + + +class Triu(Cell): + """ + Returns a tensor with elements below the kth diagonal zeroed. + + Inputs: + - **x** (Tensor) - The input tensor. + - **k** (Int) - The index of diagonal. Default: 0 + + Outputs: + Tensor, has the same type as input `x`. + + Examples: + >>> x = Tensor(np.array([[1, 2], [3, 4]])) + >>> tril = nn.Tril() + >>> result = tril(x) + >>> print(result) + [[1 2] + [0 4]] + """ + def __init__(self): + super(Triu, self).__init__() + self.dtype = P.DType() + self.mul = P.Mul() + + def construct(self, x, k=0): + assist = triu(x.shape, self.dtype(x), k) + return self.mul(x, assist) + + @constexpr def _get_matrix_diag_assist(x_shape, x_dtype): Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") diff --git a/tests/ut/python/nn/test_tril.py b/tests/ut/python/nn/test_tril.py new file mode 100644 index 0000000000..5bcc489a04 --- /dev/null +++ b/tests/ut/python/nn/test_tril.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test nn.Tril() +""" +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_tril(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + tril = nn.Tril() + return tril(self.value, 0) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 34 + + +def test_tril_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + tril = nn.Tril() + return tril(self.value, 1) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 42 + + +def test_tril_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + tril = nn.Tril() + return tril(self.value, -1) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 19 + + +def test_tril_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + tril = nn.Tril() + return tril(x, 0) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + +def test_tril_parameter_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + tril = nn.Tril() + return tril(x, 1) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + +def test_tril_parameter_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + tril = nn.Tril() + return tril(x, -1) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) diff --git a/tests/ut/python/nn/test_triu.py b/tests/ut/python/nn/test_triu.py new file mode 100644 index 0000000000..9e4c0d8f03 --- /dev/null +++ b/tests/ut/python/nn/test_triu.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test nn.Triu() +""" +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_triu(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + triu = nn.Triu() + return triu(self.value, 0) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 26 + + +def test_triu_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + triu = nn.Triu() + return triu(self.value, 1) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 11 + + +def test_triu_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def construct(self): + triu = nn.Triu() + return triu(self.value, -1) + + net = Net() + out = net() + assert np.sum(out.asnumpy()) == 38 + + +def test_triu_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + triu = nn.Triu() + return triu(x, 0) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + +def test_triu_parameter_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + triu = nn.Triu() + return triu(x, 1) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + +def test_triu_parameter_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + triu = nn.Triu() + return triu(x, -1) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))