!14227 Fix input error when axes is float or out of bound

From: @anrui-wang
Reviewed-by: @liangchenghui,@liangchenghui
Signed-off-by: @liangchenghui,@liangchenghui
pull/14227/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 80b56441f0

@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
raise ValueError("Require two axes inputs, given less")
if isinstance(axes, tuple):
axes = list(axes)
for sub_axes in axes:
if isinstance(sub_axes, (list, tuple)):
raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
# Reverse if axis < 0
if axes[0] < 0:
axes[0] += len(x1_shape)
if axes[1] < 0:
axes[1] += len(x2_shape)
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
elif isinstance(axes, int):
if axes == 0:
raise ValueError("Batch dim cannot be used as in axes.")
@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None):
"""
Computation of batch dot product between samples in two tensors containing batch dims.
.. math::
output = x1[batch, :] * x2[batch, :]
Inputs:
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None):
Outputs:
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
.. math::
output = x1[batch, :] * x2[batch, :]
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
Raises:
TypeError: If shapes of x1 and x2 are not the same.
TypeError: If type of x1 and x2 are not the same.
ValueError: If rank of x1 or x2 less than 2.
ValueError: If batch dim used in axes.
ValueError: If dtype of x1 or x2 is not float32.

Loading…
Cancel
Save