|
|
|
@ -198,7 +198,7 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
all_op_kernels = core._get_all_register_op_kernels()
|
|
|
|
|
grad_op = op_type + '_grad'
|
|
|
|
|
if grad_op in all_op_kernels.keys():
|
|
|
|
|
if hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True:
|
|
|
|
|
if is_mkldnn_op_test():
|
|
|
|
|
grad_op_kernels = all_op_kernels[grad_op]
|
|
|
|
|
for grad_op_kernel in grad_op_kernels:
|
|
|
|
|
if 'MKLDNN' in grad_op_kernel:
|
|
|
|
@ -207,6 +207,14 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def is_mkldnn_op_test():
|
|
|
|
|
if (hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True) or \
|
|
|
|
|
(hasattr(cls, "attrs") and "use_mkldnn" in cls.attrs and \
|
|
|
|
|
cls.attrs["use_mkldnn"] == True):
|
|
|
|
|
return True
|
|
|
|
|
else:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if not hasattr(cls, "op_type"):
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
"This test do not have op_type in class attrs, "
|
|
|
|
@ -226,7 +234,7 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
if cls.dtype in [np.float32, np.float64] \
|
|
|
|
|
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
|
|
|
|
|
and not hasattr(cls, 'exist_fp64_check_grad') \
|
|
|
|
|
and (not hasattr(cls, "use_mkldnn") or cls.use_mkldnn == False):
|
|
|
|
|
and not is_mkldnn_op_test():
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
"This test of %s op needs check_grad with fp64 precision." %
|
|
|
|
|
cls.op_type)
|
|
|
|
|