!14684 Add additional input check for batch dot op

From: @anrui-wang
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
pull/14684/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 89eeb256dc

@ -372,6 +372,8 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
axes[0] += len(x1_shape)
if axes[1] < 0:
axes[1] += len(x2_shape)
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
@ -380,6 +382,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
raise ValueError("Batch dim cannot be used as in axes.")
if axes < 0:
axes = [axes + len(x1_shape), axes + len(x2_shape)]
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
elif axes > len(x1_shape) or axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
@ -448,11 +451,12 @@ def batch_dot(x1, x2, axes=None):
Raises:
TypeError: If type of x1 and x2 are not the same.
TpyeError: If dtype of x1 or x2 is not float32.
ValueError: If rank of x1 or x2 less than 2.
ValueError: If batch dim used in axes.
ValueError: If dtype of x1 or x2 is not float32.
ValueError: If len(axes) less than 2.
ValueError: If axes is not one of those: None, int, (int, int).
ValueError: If axes reversed from negative int is too low for dimensions of input arrays.
ValueError: If axes value is too high for dimensions of input arrays.
ValueError: If batch size of x1 and x2 are not the same.

Loading…
Cancel
Save