udpate the dtype check for the argmin, argmax

fix the bug for dtype check for the argmin/argmax
numel
wawltor 4 years ago committed by GitHub
parent 9b7692b144
commit 39d5bb6dce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -325,16 +325,16 @@ class TestArgMinMaxOpError(unittest.TestCase):
def test_argmax_dtype_type(): def test_argmax_dtype_type():
data = paddle.static.data( data = paddle.static.data(
name="test_argmax", shape=[10], dtype="float32") name="test_argmax", shape=[10], dtype="float32")
output = paddle.argmax(x=data, dtype=1) output = paddle.argmax(x=data, dtype=None)
self.assertRaises(TypeError, test_argmax_dtype_type) self.assertRaises(ValueError, test_argmax_dtype_type)
def test_argmin_dtype_type(): def test_argmin_dtype_type():
data = paddle.static.data( data = paddle.static.data(
name="test_argmin", shape=[10], dtype="float32") name="test_argmin", shape=[10], dtype="float32")
output = paddle.argmin(x=data, dtype=1) output = paddle.argmin(x=data, dtype=None)
self.assertRaises(TypeError, test_argmin_dtype_type) self.assertRaises(ValueError, test_argmin_dtype_type)
if __name__ == '__main__': if __name__ == '__main__':

@ -167,10 +167,10 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmax, but received %s." "The type of 'axis' must be int or None in argmax, but received %s."
% (type(axis))) % (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): if dtype is None:
raise TypeError( raise ValueError(
"the type of 'dtype' in argmax must be str or np.dtype, but received {}". "the value of 'dtype' in argmax could not be None, but received None"
format(type(dtype))) )
var_dtype = convert_np_dtype_to_dtype_(dtype) var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
@ -245,10 +245,10 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
"The type of 'axis' must be int or None in argmin, but received %s." "The type of 'axis' must be int or None in argmin, but received %s."
% (type(axis))) % (type(axis)))
if not (isinstance(dtype, str) or isinstance(dtype, np.dtype)): if dtype is None:
raise TypeError( raise ValueError(
"the type of 'dtype' in argmin must be str or np.dtype, but received {}". "the value of 'dtype' in argmin could not be None, but received None"
format(dtype(dtype))) )
var_dtype = convert_np_dtype_to_dtype_(dtype) var_dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')

Loading…
Cancel
Save