|
|
|
@ -341,22 +341,42 @@ class GradOperation(GradOperation_):
|
|
|
|
|
|
|
|
|
|
class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
|
|
"""
|
|
|
|
|
Generate multiply graph.
|
|
|
|
|
Generate overloaded functions.
|
|
|
|
|
|
|
|
|
|
MultitypeFuncGraph is a class used to generate graphs for function with different type as input.
|
|
|
|
|
MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs.
|
|
|
|
|
Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
|
|
|
|
|
for the function to be registed. And the object can be called with different type of inputs,
|
|
|
|
|
and work with `HyperMap` and `Map`.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name (str): Operator name.
|
|
|
|
|
read_value (bool): If the registered function not need to set value on Parameter,
|
|
|
|
|
and all inputs will pass by value. Set `read_value` to True. Default: False.
|
|
|
|
|
and all inputs will pass by value, set `read_value` to True. Default: False.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: Cannot find matching fn for the given args.
|
|
|
|
|
ValueError: Cannot find matching functions for the given args.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> # `add` is a metagraph object which will add two objects according to
|
|
|
|
|
>>> # input type using ".register" decorator.
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> from mindspore.ops import Primitive, operations as P
|
|
|
|
|
>>> from mindspore import dtype as mstype
|
|
|
|
|
>>>
|
|
|
|
|
>>> scala_add = Primitive('scala_add')
|
|
|
|
|
>>> tensor_add = P.TensorAdd()
|
|
|
|
|
>>>
|
|
|
|
|
>>> add = MultitypeFuncGraph('add')
|
|
|
|
|
>>> @add.register("Number", "Number")
|
|
|
|
|
... def add_scala(x, y):
|
|
|
|
|
... return scala_add(x, y)
|
|
|
|
|
>>> @add.register("Tensor", "Tensor")
|
|
|
|
|
... def add_tensor(x, y):
|
|
|
|
|
... return tensor_add(x, y)
|
|
|
|
|
>>> add(1, 2)
|
|
|
|
|
3
|
|
|
|
|
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
|
|
|
|
|
Tensor(shape=[], dtype=Float32, 3)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, read_value=False):
|
|
|
|
@ -378,9 +398,25 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
|
|
raise ValueError("Cannot find fn match given args.")
|
|
|
|
|
|
|
|
|
|
def register(self, *type_names):
|
|
|
|
|
"""Register a function for the given type string."""
|
|
|
|
|
"""
|
|
|
|
|
Register a function for the given type string.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
type_names (Union[str, :class:`mindspore.dtype`]): Inputs type names or types list.
|
|
|
|
|
|
|
|
|
|
Return:
|
|
|
|
|
decorator, a decorator to register the function to run, when called under the
|
|
|
|
|
types described in `type_names`.
|
|
|
|
|
"""
|
|
|
|
|
def deco(fn):
|
|
|
|
|
types = tuple(map(mstype.typing.str_to_type, type_names))
|
|
|
|
|
def convert_type(type_input):
|
|
|
|
|
if isinstance(type_input, str):
|
|
|
|
|
return mstype.typing.str_to_type(type_input)
|
|
|
|
|
if not isinstance(type_input, mstype.Type):
|
|
|
|
|
raise TypeError(f"MultitypeFuncGraph register only support str or {mstype.Type}")
|
|
|
|
|
return type_input
|
|
|
|
|
|
|
|
|
|
types = tuple(map(convert_type, type_names))
|
|
|
|
|
self.register_fn(type_names, fn)
|
|
|
|
|
self.entries.append((types, fn))
|
|
|
|
|
return fn
|
|
|
|
@ -391,11 +427,12 @@ class HyperMap(HyperMap_):
|
|
|
|
|
"""
|
|
|
|
|
Hypermap will apply the set operation on input sequences.
|
|
|
|
|
|
|
|
|
|
Which will apply the operations of every elements of the sequence.
|
|
|
|
|
Apply the operations to every elements of the sequence or nested sequence. Different
|
|
|
|
|
from `Map`, the `HyperMap` supports to apply on nested structure.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
|
|
|
|
the operations should be putted in the first input of the instance.
|
|
|
|
|
the operations should be put in the first input of the instance.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
|
|
|
|
@ -405,8 +442,28 @@ class HyperMap(HyperMap_):
|
|
|
|
|
If `ops` is not `None`, the first input is the operation, and the other is inputs.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
sequence, the output will be same type and same length of sequence from input and the value of each element
|
|
|
|
|
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
|
|
|
|
|
Sequence or nested sequence, the sequence of output after applying the function.
|
|
|
|
|
e.g. `operation(args[0][i], args[1][i])`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from mindspore import dtype as mstype
|
|
|
|
|
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
|
|
|
|
|
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
|
|
|
|
|
>>> # square all the tensor in the nested list
|
|
|
|
|
>>>
|
|
|
|
|
>>> square = MultitypeFuncGraph('square')
|
|
|
|
|
>>> @square.register("Tensor")
|
|
|
|
|
... def square_tensor(x):
|
|
|
|
|
... return F.square(x)
|
|
|
|
|
>>>
|
|
|
|
|
>>> common_map = HyperMap()
|
|
|
|
|
>>> common_map(square, nest_tensor_list)
|
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
|
>>> square_map = HyperMap(square)
|
|
|
|
|
>>> square_map(nest_tensor_list)
|
|
|
|
|
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
|
|
|
|
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, ops=None):
|
|
|
|
@ -434,11 +491,11 @@ class Map(Map_):
|
|
|
|
|
"""
|
|
|
|
|
Map will apply the set operation on input sequences.
|
|
|
|
|
|
|
|
|
|
Which will apply the operations of every elements of the sequence.
|
|
|
|
|
Apply the operations to every elements of the sequence.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
|
|
|
|
the operations should be putted in the first input of the instance.
|
|
|
|
|
the operations should be put in the first input of the instance. Default: None
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
|
|
|
|
@ -448,8 +505,24 @@ class Map(Map_):
|
|
|
|
|
If `ops` is not `None`, the first input is the operation, and the other is inputs.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
sequence, the output will be same type and same length of sequence from input and the value of each element
|
|
|
|
|
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
|
|
|
|
|
Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from mindspore import dtype as mstype
|
|
|
|
|
>>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
|
|
|
|
|
>>> # square all the tensor in the list
|
|
|
|
|
>>>
|
|
|
|
|
>>> square = MultitypeFuncGraph('square')
|
|
|
|
|
>>> @square.register("Tensor")
|
|
|
|
|
>>> def square_tensor(x):
|
|
|
|
|
... return F.square(x)
|
|
|
|
|
>>>
|
|
|
|
|
>>> common_map = Map()
|
|
|
|
|
>>> common_map(square, tensor_list)
|
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
|
>>> square_map = Map(square)
|
|
|
|
|
>>> square_map(tensor_list)
|
|
|
|
|
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, ops=None):
|
|
|
|
|