|
|
|
@ -120,6 +120,9 @@ class GradOperation(GradOperation_):
|
|
|
|
|
""" Pynative forward run to build grad graph. """
|
|
|
|
|
if self.sens_param:
|
|
|
|
|
args = args[:-1]
|
|
|
|
|
for arg in args:
|
|
|
|
|
if not isinstance(arg, Tensor):
|
|
|
|
|
raise TypeError("grad inputs should be tensor in pynative mode")
|
|
|
|
|
if isinstance(fn, FunctionType):
|
|
|
|
|
_pynative_exec.set_grad_flag(True)
|
|
|
|
|
_pynative_exec.new_graph(fn, *args)
|
|
|
|
@ -150,9 +153,6 @@ class GradOperation(GradOperation_):
|
|
|
|
|
else:
|
|
|
|
|
@_wrap_func
|
|
|
|
|
def after_grad(*args):
|
|
|
|
|
for arg in args:
|
|
|
|
|
if not isinstance(arg, Tensor):
|
|
|
|
|
raise TypeError("grad inputs should be tensor in pynative mode")
|
|
|
|
|
self._pynative_forward_run(args, fn)
|
|
|
|
|
_pynative_exec.grad(grad_, fn, weights, *args)
|
|
|
|
|
out = _pynative_exec(*args)
|
|
|
|
|