|
|
|
@ -286,7 +286,7 @@ class GradientChecker(unittest.TestCase):
|
|
|
|
|
for no_grad in no_grad_set:
|
|
|
|
|
if no_grad not in in_names:
|
|
|
|
|
raise ValueError("no_grad should be in in_names")
|
|
|
|
|
if name in inputs_to_check:
|
|
|
|
|
if no_grad in inputs_to_check:
|
|
|
|
|
raise ValueError("no_grad should not be in inputs_to_check")
|
|
|
|
|
|
|
|
|
|
backward_op = core.Operator.backward(forward_op, no_grad_set)
|
|
|
|
@ -304,25 +304,8 @@ class GradientChecker(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
check_names = [grad_var_name(name) for name in inputs_to_check]
|
|
|
|
|
for place in places:
|
|
|
|
|
# analytic_grads = self.__get_gradient(forward_op, backward_op,
|
|
|
|
|
# input_vars, check_names, place)
|
|
|
|
|
# In fact, the above two lines can be used to replace following
|
|
|
|
|
# codes. But most of the gradient operators need to handle the case
|
|
|
|
|
# where one of more of the gradient of the input is not needed.
|
|
|
|
|
# We change the unit test framework to explicitly test whether
|
|
|
|
|
# the operator correctly handles this through follow codes.
|
|
|
|
|
# In addtion, if all the inputs have no gradients, the NOP operator
|
|
|
|
|
# will be returned by core.Operator.backward(). The following codes
|
|
|
|
|
# do not test this case.
|
|
|
|
|
analytic_grads = []
|
|
|
|
|
for name in inputs_to_check:
|
|
|
|
|
no_grads = [name for name in no_grad_set]
|
|
|
|
|
no_grads.extend(filter(lambda x: x != name, inputs_to_check))
|
|
|
|
|
backward_op = core.Operator.backward(forward_op, set(no_grads))
|
|
|
|
|
# get analytical gradients according to different device
|
|
|
|
|
analytic_grads.extend(
|
|
|
|
|
self.__get_gradient(forward_op, backward_op, input_vars,
|
|
|
|
|
[grad_var_name(name)], place))
|
|
|
|
|
analytic_grads = self.__get_gradient(forward_op, backward_op,
|
|
|
|
|
input_vars, check_names, place)
|
|
|
|
|
self.__assert_is_close(numeric_grads, analytic_grads, check_names,
|
|
|
|
|
max_relative_error,
|
|
|
|
|
"Gradient Check On %s" % str(place))
|
|
|
|
|