From 3fc2c16fea2167fe2965d2ab8568ee2ceee0911c Mon Sep 17 00:00:00 2001 From: w00535372 Date: Sat, 27 Mar 2021 14:26:39 +0800 Subject: [PATCH] Add input check for axes which is float type or out of bound --- mindspore/ops/composite/math_ops.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index 04c60644ac..d6f197b3a2 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): raise ValueError("Require two axes inputs, given less") if isinstance(axes, tuple): axes = list(axes) - for sub_axes in axes: - if isinstance(sub_axes, (list, tuple)): - raise ValueError("Require dimension to be in any of those: None, int, (int, int).") + validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot') + validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot') # Reverse if axis < 0 if axes[0] < 0: axes[0] += len(x1_shape) if axes[1] < 0: axes[1] += len(x2_shape) + if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): + raise ValueError( + "Axes value too high for given input arrays dimensions.") elif isinstance(axes, int): if axes == 0: raise ValueError("Batch dim cannot be used as in axes.") @@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None): """ Computation of batch dot product between samples in two tensors containing batch dims. + .. math:: + output = x1[batch, :] * x2[batch, :] + Inputs: - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32 - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should @@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None): Outputs: Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and - (batch, d3, axes, d4) is (batch, d1, d2, d3, d4) - - .. math:: - output = x1[batch, :] * x2[batch, :] + (batch, d3, axes, d4) is (batch, d1, d2, d3, d4) Raises: - TypeError: If shapes of x1 and x2 are not the same. + TypeError: If type of x1 and x2 are not the same. ValueError: If rank of x1 or x2 less than 2. ValueError: If batch dim used in axes. ValueError: If dtype of x1 or x2 is not float32.