|
|
|
@ -208,12 +208,7 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def is_mkldnn_op_test():
|
|
|
|
|
if (hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True) or \
|
|
|
|
|
(hasattr(cls, "attrs") and "use_mkldnn" in cls.attrs and \
|
|
|
|
|
cls.attrs["use_mkldnn"] == True):
|
|
|
|
|
return True
|
|
|
|
|
else:
|
|
|
|
|
return False
|
|
|
|
|
return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True
|
|
|
|
|
|
|
|
|
|
if not hasattr(cls, "op_type"):
|
|
|
|
|
raise AssertionError(
|
|
|
|
@ -321,8 +316,10 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
def _append_ops(self, block):
|
|
|
|
|
self.__class__.op_type = self.op_type # for ci check, please not delete it for now
|
|
|
|
|
if hasattr(self, "use_mkldnn"):
|
|
|
|
|
self.__class__.use_mkldnn = self.use_mkldnn
|
|
|
|
|
if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
|
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
|
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
|
|
|
|
|
"infer datatype from inputs and outputs for this test case"
|
|
|
|
|
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
|
|
|
|
@ -1189,8 +1186,10 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
check_dygraph=True,
|
|
|
|
|
inplace_atol=None):
|
|
|
|
|
self.__class__.op_type = self.op_type
|
|
|
|
|
if hasattr(self, "use_mkldnn"):
|
|
|
|
|
self.__class__.use_mkldnn = self.use_mkldnn
|
|
|
|
|
if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
|
|
|
|
|
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
|
|
|
|
|
self.attrs["use_mkldnn"] == True):
|
|
|
|
|
self.__class__.use_mkldnn = True
|
|
|
|
|
places = self._get_places()
|
|
|
|
|
for place in places:
|
|
|
|
|
res = self.check_output_with_place(place, atol, no_check_set,
|
|
|
|
|