From 9d3c9c69fee3fd969b6f1ca163841f6b6f1d1960 Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Wed, 24 Jun 2020 11:40:46 +0800 Subject: [PATCH] modify map to C.Map() --- mindspore/_extends/parse/resources.py | 2 +- mindspore/ccsrc/operator/composite/map.cc | 3 +++ .../ut/python/pipeline/parse/test_fix_bug.py | 23 +++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 2ae8b7172f..eb89c965df 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -111,7 +111,7 @@ convert_object_map = { # system function T.len: M.ms_len, T.bool: M.bool_, - T.map: C.HyperMap(), + T.map: C.Map(), T.partial: F.partial, T.zip: C.zip_operation, T.print: F.print_, diff --git a/mindspore/ccsrc/operator/composite/map.cc b/mindspore/ccsrc/operator/composite/map.cc index a054da5f4d..6062f0f5af 100644 --- a/mindspore/ccsrc/operator/composite/map.cc +++ b/mindspore/ccsrc/operator/composite/map.cc @@ -181,6 +181,9 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGrap } AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + if (arg_pairs.size() < 1) { + MS_EXCEPTION(TypeError) << "map() must have at least two arguments"; + } bool found = false; TypeId id = kObjectTypeEnd; std::pair pair; diff --git a/tests/ut/python/pipeline/parse/test_fix_bug.py b/tests/ut/python/pipeline/parse/test_fix_bug.py index 5bf7db3798..9b013f95a4 100644 --- a/tests/ut/python/pipeline/parse/test_fix_bug.py +++ b/tests/ut/python/pipeline/parse/test_fix_bug.py @@ -18,6 +18,7 @@ import pytest import mindspore.nn as nn from mindspore import Tensor +from mindspore.ops import composite as C from mindspore.common.api import _executor @@ -93,3 +94,25 @@ def test_compile_unspported(): net = unsupported_method_net() with pytest.raises(RuntimeError): _executor.compile(net, input_me) + + +def test_parser_map_0002(): + class NetMap0002(nn.Cell): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + self.hypermap = C.Map() + + def mul(self, x=2, y=4): + return x * y + + def construct(self, x): + if map(self.mul) == 8: + x = self.relu(x) + return x + input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_x = Tensor(input_np_x) + + net = NetMap0002() + with pytest.raises(TypeError): + net(input_me_x)