|
|
|
@ -75,7 +75,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
|
|
|
|
|
|
|
|
return nonzero_num
|
|
|
|
|
|
|
|
|
|
# TensorDot
|
|
|
|
|
# tensor dot
|
|
|
|
|
@constexpr
|
|
|
|
|
def _int_to_tuple_conv(axes):
|
|
|
|
|
"""
|
|
|
|
@ -92,7 +92,7 @@ def _check_axes(axes):
|
|
|
|
|
"""
|
|
|
|
|
Check for validity and type of axes passed to function.
|
|
|
|
|
"""
|
|
|
|
|
validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot")
|
|
|
|
|
validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
|
|
|
|
|
if not isinstance(axes, int):
|
|
|
|
|
axes = list(axes) # to avoid immutability issues
|
|
|
|
|
if len(axes) != 2:
|
|
|
|
@ -156,7 +156,7 @@ def _calc_new_shape(shape, axes, position=0):
|
|
|
|
|
return new_shape, transpose_perm, free_dims
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def TensorDot(x1, x2, axes):
|
|
|
|
|
def tensor_dot(x1, x2, axes):
|
|
|
|
|
"""
|
|
|
|
|
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.
|
|
|
|
|
|
|
|
|
@ -171,8 +171,8 @@ def TensorDot(x1, x2, axes):
|
|
|
|
|
axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b`
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **x1** (Tensor) - First tensor in TensorDot op with datatype float16 or float32
|
|
|
|
|
- **x2** (Tensor) - Second tensor in TensorDot op with datatype float16 or float32
|
|
|
|
|
- **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32
|
|
|
|
|
- **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32
|
|
|
|
|
- **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
|
|
|
|
|
tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
|
|
|
|
|
automatically picks up first N dims from `a` input shape and last N dims from `b` input shape.
|
|
|
|
@ -184,7 +184,7 @@ def TensorDot(x1, x2, axes):
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
|
|
|
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
|
|
|
|
|
>>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2)))
|
|
|
|
|
>>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[2,2,2],
|
|
|
|
|
[2,2,2],
|
|
|
|
@ -206,7 +206,7 @@ def TensorDot(x1, x2, axes):
|
|
|
|
|
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
|
|
|
|
|
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
|
|
|
|
|
output_shape = x1_ret + x2_ret # combine free axes from both inputs
|
|
|
|
|
# run TensorDot op
|
|
|
|
|
# run tensor_dot op
|
|
|
|
|
x1_transposed = transpose_op(x1, x1_transpose_fwd)
|
|
|
|
|
x2_transposed = transpose_op(x2, x2_transpose_fwd)
|
|
|
|
|
x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
|
|
|
|
|