|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from mindspore.common.api import ms_function
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.ops import Primitive
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor_add = P.TensorAdd()
|
|
|
|
|
op_add = P.AddN()
|
|
|
|
|
scala_add = Primitive('scalar_add')
|
|
|
|
|
add = C.MultitypeFuncGraph('add')
|
|
|
|
|
|
|
|
|
@ -50,5 +52,14 @@ def test_multitype_tensor():
|
|
|
|
|
mainf(tensor1, tensor2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@non_graph_engine
|
|
|
|
|
def test_multitype_tuple():
|
|
|
|
|
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
|
|
|
|
params1 = Parameter(tensor1, name="params1")
|
|
|
|
|
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
|
|
|
|
output = op_add((params1, tensor2))
|
|
|
|
|
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_multitype_scalar():
|
|
|
|
|
mainf(1, 2)
|
|
|
|
|