|
|
@ -384,9 +384,11 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
|
>>> @add.register("Tensor", "Tensor")
|
|
|
|
>>> @add.register("Tensor", "Tensor")
|
|
|
|
... def add_tensor(x, y):
|
|
|
|
... def add_tensor(x, y):
|
|
|
|
... return tensor_add(x, y)
|
|
|
|
... return tensor_add(x, y)
|
|
|
|
>>> add(1, 2)
|
|
|
|
>>> ourput = add(1, 2)
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
3
|
|
|
|
3
|
|
|
|
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
|
|
|
|
>>> output = add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
Tensor(shape=[], dtype=Float32, 3)
|
|
|
|
Tensor(shape=[], dtype=Float32, 3)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
@ -470,11 +472,13 @@ class HyperMap(HyperMap_):
|
|
|
|
... return F.square(x)
|
|
|
|
... return F.square(x)
|
|
|
|
>>>
|
|
|
|
>>>
|
|
|
|
>>> common_map = HyperMap()
|
|
|
|
>>> common_map = HyperMap()
|
|
|
|
>>> common_map(square, nest_tensor_list)
|
|
|
|
>>> output = common_map(square, nest_tensor_list)
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
>>> square_map = HyperMap(square)
|
|
|
|
>>> square_map = HyperMap(square)
|
|
|
|
>>> square_map(nest_tensor_list)
|
|
|
|
>>> output = square_map(nest_tensor_list)
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -531,10 +535,12 @@ class Map(Map_):
|
|
|
|
... return F.square(x)
|
|
|
|
... return F.square(x)
|
|
|
|
>>>
|
|
|
|
>>>
|
|
|
|
>>> common_map = Map()
|
|
|
|
>>> common_map = Map()
|
|
|
|
>>> common_map(square, tensor_list)
|
|
|
|
>>> output = common_map(square, tensor_list)
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
>>> square_map = Map(square)
|
|
|
|
>>> square_map = Map(square)
|
|
|
|
>>> square_map(tensor_list)
|
|
|
|
>>> output = square_map(tensor_list)
|
|
|
|
|
|
|
|
>>> print(output)
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|