!7492 [ME] bug fix for check parameter

Merge pull request !7492 from chenzhongming/zomi_master
pull/7492/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1b6c2c577c

@ -75,12 +75,12 @@ rel_fns = {
rel_strs = {
# scalar compare
Rel.EQ: "equal to {}",
Rel.NE: "not equal to {}",
Rel.LT: "less than {}",
Rel.LE: "less or equal to {}",
Rel.GT: "greater than {}",
Rel.GE: "greater or equal to {}",
Rel.EQ: "== {}",
Rel.NE: "!= {}",
Rel.LT: "< {}",
Rel.LE: "<= {}",
Rel.GT: "> {}",
Rel.GE: ">= {}",
# scalar range check
Rel.INC_NEITHER: "({}, {})",
Rel.INC_LEFT: "[{}, {})",
@ -102,12 +102,16 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
type_except = TypeError if type_mismatch else ValueError
prim_name = f'in `{prim_name}`' if prim_name else ''
arg_name = f'`{arg_name}`' if arg_name else ''
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
arg_name = arg_name if arg_name else "parameter"
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, '
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value
@ -123,7 +127,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
if math.isinf(arg_value) or math.isnan(arg_value):
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
return arg_value
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
@ -137,14 +141,15 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
- number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
- number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
"""
rel_fn = Rel.get_fns(rel)
prim_name = f'in `{prim_name}`' if prim_name else ''
arg_name = f'`{arg_name}`' if arg_name else ''
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
if type_mismatch:
raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.')
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format(
raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format(
arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__))
return arg_value

Loading…
Cancel
Save