diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index c29aacbaef..5e44edf9b9 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -153,6 +153,12 @@ def enumerate_(x, start=0): return ret +def expand_tensor_as(x, y): + """Expand tensor""" + broadcast_to = P.BroadcastTo(shape_(y)) + return broadcast_to(x) + + def isinstance_(x, base_type): """Determine whether x is an instance of base_type.""" x_type = F.typeof(x) diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index ede5c59bae..2e5e3216b6 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -145,34 +145,35 @@ BuiltInTypeMap &GetMethodMap() { }}, {kObjectTypeTensorType, { - {"all", std::string("all_")}, // C.reduce_all - {"any", std::string("any_")}, // C.reduce_any - {"__add__", std::string("add")}, // C.add - {"__sub__", std::string("sub")}, // C.sub - {"__mul__", std::string("mul")}, // C.mul - {"__truediv__", std::string("truediv")}, // C.truediv - {"__floordiv__", std::string("floordiv")}, // C.floordiv - {"__mod__", std::string("mod")}, // C.mod - {"__pow__", std::string("pow_")}, // C.pow - {"__floor__", std::string("array_floor")}, // C.array_floor - {"__trunc__", std::string("array_trunc")}, // C.array_trunc - {"__pos__", std::string("array_uadd")}, // C.array_uadd - {"__neg__", std::string("array_usub")}, // C.array_usub - {"__eq__", std::string("eq")}, // C.eq - {"__ne__", std::string("ne")}, // C.ne - {"__lt__", std::string("lt")}, // C.lt - {"__gt__", std::string("gt")}, // C.gt - {"__le__", std::string("le")}, // C.le - {"__ge__", std::string("ge")}, // C.ge - {"__matmul__", prim::kPrimDot}, // P.dot, - {"__len__", prim::kPrimArrayLen}, // P.array_len, - {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, - {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, - {"__ms_iter__", std::string("array_iter")}, // C.array_iter - {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, - {"transpose", std::string("transpose")}, // P.transpose - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + {"all", std::string("all_")}, // C.reduce_all + {"any", std::string("any_")}, // C.reduce_any + {"__add__", std::string("add")}, // C.add + {"__sub__", std::string("sub")}, // C.sub + {"__mul__", std::string("mul")}, // C.mul + {"__truediv__", std::string("truediv")}, // C.truediv + {"__floordiv__", std::string("floordiv")}, // C.floordiv + {"__mod__", std::string("mod")}, // C.mod + {"__pow__", std::string("pow_")}, // C.pow + {"__floor__", std::string("array_floor")}, // C.array_floor + {"__trunc__", std::string("array_trunc")}, // C.array_trunc + {"__pos__", std::string("array_uadd")}, // C.array_uadd + {"__neg__", std::string("array_usub")}, // C.array_usub + {"__eq__", std::string("eq")}, // C.eq + {"__ne__", std::string("ne")}, // C.ne + {"__lt__", std::string("lt")}, // C.lt + {"__gt__", std::string("gt")}, // C.gt + {"__le__", std::string("le")}, // C.le + {"__ge__", std::string("ge")}, // C.ge + {"expand_as", std::string("expand_tensor_as")}, // C.expand_as + {"__matmul__", prim::kPrimDot}, // P.dot, + {"__len__", prim::kPrimArrayLen}, // P.array_len, + {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, + {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, + {"__ms_iter__", std::string("array_iter")}, // C.array_iter + {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, + {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"transpose", std::string("transpose")}, // P.transpose + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool }}, {kObjectTypeJTagged, {}}, {kObjectTypeSymbolicKeyType, {}}, diff --git a/tests/ut/python/pipeline/parse/test_expand_as.py b/tests/ut/python/pipeline/parse/test_expand_as.py new file mode 100644 index 0000000000..36583e065f --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_expand_as.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test expand_as""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_expand_as(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.t1 = Tensor([1, 2, 3]) + self.t2 = Tensor([[1, 1, 1], [1, 1, 1]]) + + def construct(self): + return self.t1.expand_as(self.t2) + + net = Net() + net() + + +def test_expand_as_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.t1 = Tensor([1, 2, 3]) + + def construct(self, x): + return self.t1.expand_as(x) + + net = Net() + net(Tensor([[1, 1, 1], [1, 1, 1]])) + + +def test_expand_tensor_as_parameter_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.t2 = Tensor([[1, 1, 1], [1, 1, 1]]) + + def construct(self, x): + return x.expand_as(self.t2) + + net = Net() + net(Tensor([1, 2, 3]))