|
|
|
|
@ -70,8 +70,6 @@ def get_numeric_gradient(place,
|
|
|
|
|
|
|
|
|
|
tensor_to_check = scope.find_var(input_to_check).get_tensor()
|
|
|
|
|
tensor_size = product(tensor_to_check.shape())
|
|
|
|
|
if tensor_size < 100:
|
|
|
|
|
get_numeric_gradient.is_large_shape = False
|
|
|
|
|
tensor_to_check_dtype = tensor_to_check._dtype()
|
|
|
|
|
if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
|
|
|
|
|
tensor_to_check_dtype = np.float32
|
|
|
|
|
@ -178,14 +176,13 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
cls.call_once = False
|
|
|
|
|
cls.dtype = None
|
|
|
|
|
cls.outputs = {}
|
|
|
|
|
cls.input_shape_is_large = True
|
|
|
|
|
|
|
|
|
|
np.random.seed(123)
|
|
|
|
|
random.seed(124)
|
|
|
|
|
|
|
|
|
|
cls._use_system_allocator = _set_use_system_allocator(True)
|
|
|
|
|
|
|
|
|
|
get_numeric_gradient.is_large_shape = True
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def tearDownClass(cls):
|
|
|
|
|
"""Restore random seeds"""
|
|
|
|
|
@ -238,7 +235,7 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
"This test of %s op needs check_grad with fp64 precision." %
|
|
|
|
|
cls.op_type)
|
|
|
|
|
|
|
|
|
|
if not get_numeric_gradient.is_large_shape \
|
|
|
|
|
if not cls.input_shape_is_large \
|
|
|
|
|
and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
"Input's shape should be large than or equal to 100 for " +
|
|
|
|
|
@ -1319,6 +1316,14 @@ class OpTest(unittest.TestCase):
|
|
|
|
|
raise AssertionError("no_grad_set must be None, op_type is " +
|
|
|
|
|
self.op_type + " Op.")
|
|
|
|
|
|
|
|
|
|
for input_to_check in inputs_to_check:
|
|
|
|
|
set_input(self.scope, self.op, self.inputs, place)
|
|
|
|
|
tensor_to_check = self.scope.find_var(input_to_check).get_tensor()
|
|
|
|
|
tensor_size = six.moves.reduce(lambda a, b: a * b,
|
|
|
|
|
tensor_to_check.shape(), 1)
|
|
|
|
|
if tensor_size < 100:
|
|
|
|
|
self.__class__.input_shape_is_large = False
|
|
|
|
|
|
|
|
|
|
if not type(output_names) is list:
|
|
|
|
|
output_names = [output_names]
|
|
|
|
|
|
|
|
|
|
|