diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 27c4112d5f..7f80f65305 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -28,6 +28,7 @@ __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_lik trans = P.Transpose() shape_ = P.Shape() +reshape_ = P.Reshape() dtype_ = P.DType() def all_(x, axis=(), keep_dims=False): @@ -162,6 +163,12 @@ def expand_tensor_as(x, y): return broadcast_to(x) +def view(x, *shape): + """Reshape tensor, if shape is -1, reshape tensor into one dimension""" + shape = check_view_shape(shape) + return reshape_(x, shape) + + def isinstance_(x, base_type): """Determine whether x is an instance of base_type.""" x_type = F.typeof(x) @@ -232,6 +239,18 @@ def const_tensor_to_bool(x): raise ValueError("The truth value of an array with several elements is ambiguous.") +@constexpr +def check_view_shape(x): + """Check view function input shape""" + if not x: + raise ValueError("The shape variable should not be empty") + if isinstance(x[0], tuple): + if len(x) != 1: + raise ValueError(f"Only one tuple is needed, but got {x}") + x = x[0] + return x + + def tensor_bool(x): """tensor as conditon, if is constant, return immediate bool value""" is_cond = check_is_tensor_bool_cond(F.shape(x)) diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index e970c0c2e9..03941f86ff 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -167,6 +167,7 @@ BuiltInTypeMap &GetMethodMap() { {"__le__", std::string("le")}, // C.le {"__ge__", std::string("ge")}, // C.ge {"expand_as", std::string("expand_tensor_as")}, // C.expand_as + {"view", std::string("view")}, // C.view {"__matmul__", prim::kPrimDot}, // P.dot, {"__len__", prim::kPrimArrayLen}, // P.array_len, {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, diff --git a/tests/ut/python/pipeline/parse/test_view.py b/tests/ut/python/pipeline/parse/test_view.py new file mode 100644 index 0000000000..13e085d93c --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_view.py @@ -0,0 +1,127 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test view""" +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_view(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self): + return self.value.view(-1) + + net = Net() + net() + + +def test_view_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self): + return self.value.view((3, 2)) + + net = Net() + net() + + +def test_view_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self): + return self.value.view(3, 2) + + net = Net() + net() + + +def test_view_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return x.view(-1) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6]])) + + +def test_view_parameter_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return x.view((3, 2)) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6]])) + + +def test_view_parameter_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return x.view(3, 2) + + net = Net() + net(Tensor([[1, 2, 3], [4, 5, 6]])) + + +def test_view_shape_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self): + return self.value.view() + + net = Net() + with pytest.raises(ValueError) as ex: + net() + assert "The shape variable should not be empty" in str(ex.value) + + +def test_view_shape_error_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self): + return self.value.view((2, 3), (4, 5)) + + net = Net() + with pytest.raises(ValueError) as ex: + net() + assert "Only one tuple is needed" in str(ex.value)