From a325e4f711f3a8424c606b3723ad541b0a2a1bcf Mon Sep 17 00:00:00 2001 From: simson <526422051@qq.com> Date: Thu, 10 Sep 2020 21:05:59 +0800 Subject: [PATCH] add param check for gradoperation --- mindspore/ops/composite/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 59d6bbfe54..0666174cdd 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -284,6 +284,12 @@ class GradOperation(GradOperation_): """ def __init__(self, get_all=False, get_by_list=False, sens_param=False): + if not isinstance(get_all, bool): + raise TypeError(f'get_all should be bool, but got {type(get_all)}') + if not isinstance(get_by_list, bool): + raise TypeError(f'get_by_list should be bool, but got {type(get_by_list)}') + if not isinstance(sens_param, bool): + raise TypeError(f'sens_param should be bool, but got {type(sens_param)}') self.get_all = get_all self.get_by_list = get_by_list self.sens_param = sens_param