|
|
|
@ -209,8 +209,8 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
if not hasattr(cls, "op_type"):
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
"This test do not have op_type in class attrs,"
|
|
|
|
|
" please set self.__class__.op_type=the_real_op_type manually.")
|
|
|
|
|
"This test do not have op_type in class attrs, "
|
|
|
|
|
"please set self.__class__.op_type=the_real_op_type manually.")
|
|
|
|
|
|
|
|
|
|
# case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
|
|
|
|
|
if not hasattr(cls, "no_need_check_grad") \
|
|
|
|
@ -222,9 +222,11 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
raise AssertionError("This test of %s op needs check_grad." %
|
|
|
|
|
cls.op_type)
|
|
|
|
|
|
|
|
|
|
# check for op test with fp64 precision, but not check mkldnn op test for now
|
|
|
|
|
if cls.dtype in [np.float32, np.float64] \
|
|
|
|
|
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
|
|
|
|
|
and not hasattr(cls, 'exist_fp64_check_grad'):
|
|
|
|
|
and not hasattr(cls, 'exist_fp64_check_grad') \
|
|
|
|
|
and (not hasattr(cls, "use_mkldnn") or cls.use_mkldnn == False):
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
"This test of %s op needs check_grad with fp64 precision." %
|
|
|
|
|
cls.op_type)
|
|
|
|
|