|
|
|
@ -336,6 +336,17 @@ def _get_batch_size(x1_shape, x2_shape):
|
|
|
|
|
return x1_shape[0], x2_shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type):
    """
    Validate the dtypes of both batch dot inputs.

    Each input must be float32, and both inputs must share the same type;
    a TypeError is raised otherwise.
    """
    # Run the per-input dtype validation for both operands in order.
    for arg_type, arg_name in ((x1_type, 'x1'), (x2_type, 'x2')):
        const_utils.check_type_valid(arg_type, [mstype.float32], arg_name)
    # The two operands must also agree with each other.
    if x1_type != x2_type:
        raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
|
|
|
|
|
"""
|
|
|
|
@ -419,15 +430,29 @@ def batch_dot(x1, x2, axes=None):
|
|
|
|
|
Computation of batch dot product between samples in two tensors containing batch dims.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
|
|
|
|
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
|
|
|
|
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
|
|
|
|
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
|
|
|
|
|
be same as x1's.
|
|
|
|
|
- **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
|
|
|
|
|
specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
|
|
|
|
|
`a` input shape and last N dims from `b` input shape in order as axes for each respectively.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, batch dot product of x1 and x2.
|
|
|
|
|
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
|
|
|
|
|
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
output = x1[batch, :] * x2[batch, :]
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
        TypeError: If type of x1 and x2 are not the same.
|
|
|
|
|
ValueError: If rank of x1 or x2 less than 2.
|
|
|
|
|
ValueError: If batch dim used in axes.
|
|
|
|
|
        TypeError: If dtype of x1 or x2 is not float32.
|
|
|
|
|
ValueError: If len(axes) less than 2.
|
|
|
|
|
ValueError: If axes is not one of those: None, int, (int, int).
|
|
|
|
|
ValueError: If axes value is too high for dimensions of input arrays.
|
|
|
|
|
ValueError: If batch size of x1 and x2 are not the same.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
@ -458,7 +483,7 @@ def batch_dot(x1, x2, axes=None):
|
|
|
|
|
|
|
|
|
|
x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
|
|
|
|
|
|
|
|
|
|
_typecheck_input(x1_type, x2_type)
|
|
|
|
|
_typecheck_input_batch_dot(x1_type, x2_type)
|
|
|
|
|
_check_batch_size(x1_batch_size, x2_batch_size)
|
|
|
|
|
axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
|
|
|
|
|
|
|
|
|
|