|
|
@ -210,6 +210,9 @@ class OpTest(unittest.TestCase):
|
|
|
|
def is_mkldnn_op_test():
|
|
|
|
def is_mkldnn_op_test():
|
|
|
|
return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True
|
|
|
|
return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_ngraph_op_test():
|
|
|
|
|
|
|
|
return hasattr(cls, "use_ngraph") and cls.use_ngraph == True
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(cls, "op_type"):
|
|
|
|
if not hasattr(cls, "op_type"):
|
|
|
|
raise AssertionError(
|
|
|
|
raise AssertionError(
|
|
|
|
"This test do not have op_type in class attrs, "
|
|
|
|
"This test do not have op_type in class attrs, "
|
|
|
@ -229,6 +232,7 @@ class OpTest(unittest.TestCase):
|
|
|
|
if cls.dtype in [np.float32, np.float64] \
|
|
|
|
if cls.dtype in [np.float32, np.float64] \
|
|
|
|
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
|
|
|
|
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
|
|
|
|
and not hasattr(cls, 'exist_fp64_check_grad') \
|
|
|
|
and not hasattr(cls, 'exist_fp64_check_grad') \
|
|
|
|
|
|
|
|
and not is_ngraph_op_test() \
|
|
|
|
and not is_mkldnn_op_test():
|
|
|
|
and not is_mkldnn_op_test():
|
|
|
|
raise AssertionError(
|
|
|
|
raise AssertionError(
|
|
|
|
"This test of %s op needs check_grad with fp64 precision." %
|
|
|
|
"This test of %s op needs check_grad with fp64 precision." %
|
|
|
@ -320,6 +324,10 @@ class OpTest(unittest.TestCase):
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
|
|
|
|
if fluid.core.is_compiled_with_ngraph() and \
|
|
|
|
|
|
|
|
fluid.core.globals()['FLAGS_use_ngraph']:
|
|
|
|
|
|
|
|
self.__class__.use_ngraph = True
|
|
|
|
|
|
|
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
|
|
|
|
"infer datatype from inputs and outputs for this test case"
|
|
|
|
"infer datatype from inputs and outputs for this test case"
|
|
|
|
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
|
|
|
|
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
|
|
|
@ -936,14 +944,16 @@ class OpTest(unittest.TestCase):
|
|
|
|
attrs_use_mkldnn = hasattr(
|
|
|
|
attrs_use_mkldnn = hasattr(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
'attrs') and bool(self.attrs.get('use_mkldnn', False))
|
|
|
|
'attrs') and bool(self.attrs.get('use_mkldnn', False))
|
|
|
|
|
|
|
|
flags_use_ngraph = fluid.core.globals()["FLAGS_use_ngraph"]
|
|
|
|
|
|
|
|
attrs_use_ngraph = hasattr(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
'attrs') and bool(self.attrs.get('use_ngraph', False))
|
|
|
|
if flags_use_mkldnn or attrs_use_mkldnn:
|
|
|
|
if flags_use_mkldnn or attrs_use_mkldnn:
|
|
|
|
warnings.warn(
|
|
|
|
warnings.warn(
|
|
|
|
"check inplace_grad for ops using mkldnn is not supported"
|
|
|
|
"check inplace_grad for ops using mkldnn is not supported"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
use_ngraph = fluid.core.is_compiled_with_ngraph(
|
|
|
|
if flags_use_ngraph or attrs_use_ngraph:
|
|
|
|
) and fluid.core.globals()["FLAGS_use_ngraph"]
|
|
|
|
|
|
|
|
if use_ngraph:
|
|
|
|
|
|
|
|
warnings.warn(
|
|
|
|
warnings.warn(
|
|
|
|
"check inplace_grad for ops using ngraph is not supported"
|
|
|
|
"check inplace_grad for ops using ngraph is not supported"
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -1190,6 +1200,10 @@ class OpTest(unittest.TestCase):
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
|
|
|
|
if fluid.core.is_compiled_with_ngraph() and \
|
|
|
|
|
|
|
|
fluid.core.globals()['FLAGS_use_ngraph']:
|
|
|
|
|
|
|
|
self.__class__.use_ngraph = True
|
|
|
|
|
|
|
|
|
|
|
|
places = self._get_places()
|
|
|
|
places = self._get_places()
|
|
|
|
for place in places:
|
|
|
|
for place in places:
|
|
|
|
res = self.check_output_with_place(place, atol, no_check_set,
|
|
|
|
res = self.check_output_with_place(place, atol, no_check_set,
|
|
|
|