|
|
|
@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape):
|
|
|
|
|
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _get_transpose_shape(x2_shape):
|
|
|
|
|
x2_shape_range = tuple(range(len(x2_shape)))
|
|
|
|
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
|
|
|
|
return x2_shape_transpose
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dot(x1, x2):
|
|
|
|
|
"""
|
|
|
|
|
Computation a dot product between samples in two tensors.
|
|
|
|
@ -304,8 +311,7 @@ def dot(x1, x2):
|
|
|
|
|
_check_invalid_input(x1_shape, x2_shape)
|
|
|
|
|
|
|
|
|
|
if len(x1_shape) > 2 or len(x2_shape) > 2:
|
|
|
|
|
x2_shape_range = range(len(x2_shape))
|
|
|
|
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
|
|
|
|
x2_shape_transpose = _get_transpose_shape(x2_shape)
|
|
|
|
|
x2_transpose = transpose_op(x2, x2_shape_transpose)
|
|
|
|
|
x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
|
|
|
|
|
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
|
|
|
|
|