modify map to C.Map()

pull/2548/head
huangdongrun 5 years ago
parent cf0eca5608
commit 9d3c9c69fe

@ -111,7 +111,7 @@ convert_object_map = {
# system function
T.len: M.ms_len,
T.bool: M.bool_,
T.map: C.HyperMap(),
T.map: C.Map(),
T.partial: F.partial,
T.zip: C.zip_operation,
T.print: F.print_,

@ -181,6 +181,9 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
}
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
if (arg_pairs.size() < 1) {
MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
}
bool found = false;
TypeId id = kObjectTypeEnd;
std::pair<AnfNodePtr, TypePtr> pair;

@ -18,6 +18,7 @@ import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.common.api import _executor
@ -93,3 +94,25 @@ def test_compile_unspported():
net = unsupported_method_net()
with pytest.raises(RuntimeError):
_executor.compile(net, input_me)
def test_parser_map_0002():
class NetMap0002(nn.Cell):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
self.hypermap = C.Map()
def mul(self, x=2, y=4):
return x * y
def construct(self, x):
if map(self.mul) == 8:
x = self.relu(x)
return x
input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me_x = Tensor(input_np_x)
net = NetMap0002()
with pytest.raises(TypeError):
net(input_me_x)

Loading…
Cancel
Save