diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 59d6bbfe54..0666174cdd 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -284,6 +284,12 @@ class GradOperation(GradOperation_): """ def __init__(self, get_all=False, get_by_list=False, sens_param=False): + if not isinstance(get_all, bool): + raise TypeError(f'get_all should be bool, but got {type(get_all)}') + if not isinstance(get_by_list, bool): + raise TypeError(f'get_by_list should be bool, but got {type(get_by_list)}') + if not isinstance(sens_param, bool): + raise TypeError(f'sens_param should be bool, but got {type(sens_param)}') self.get_all = get_all self.get_by_list = get_by_list self.sens_param = sens_param