|
|
|
@ -372,6 +372,8 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
|
|
|
|
|
axes[0] += len(x1_shape)
|
|
|
|
|
if axes[1] < 0:
|
|
|
|
|
axes[1] += len(x2_shape)
|
|
|
|
|
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
|
|
|
|
|
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
|
|
|
|
|
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Axes value too high for given input arrays dimensions.")
|
|
|
|
@ -380,6 +382,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
|
|
|
|
|
raise ValueError("Batch dim cannot be used as in axes.")
|
|
|
|
|
if axes < 0:
|
|
|
|
|
axes = [axes + len(x1_shape), axes + len(x2_shape)]
|
|
|
|
|
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
|
|
|
|
|
elif axes > len(x1_shape) or axes > len(x2_shape):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Axes value too high for given input arrays dimensions.")
|
|
|
|
@ -448,11 +451,12 @@ def batch_dot(x1, x2, axes=None):
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If type of x1 and x2 are not the same.
|
|
|
|
|
TpyeError: If dtype of x1 or x2 is not float32.
|
|
|
|
|
ValueError: If rank of x1 or x2 less than 2.
|
|
|
|
|
ValueError: If batch dim used in axes.
|
|
|
|
|
ValueError: If dtype of x1 or x2 is not float32.
|
|
|
|
|
ValueError: If len(axes) less than 2.
|
|
|
|
|
ValueError: If axes is not one of those: None, int, (int, int).
|
|
|
|
|
ValueError: If axes reversed from negative int is too low for dimensions of input arrays.
|
|
|
|
|
ValueError: If axes value is too high for dimensions of input arrays.
|
|
|
|
|
ValueError: If batch size of x1 and x2 are not the same.
|
|
|
|
|
|
|
|
|
|