diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index d72f89399e..38a3e2a5f5 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -351,13 +351,13 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
     for (size_t i = 0; i < op_inputs.size(); i++) {
       py::object input = op_inputs[i];
       if (py::hasattr(input, "__parameter__")) {
-        result[i] = py::getattr(input, "data");
-      } else {
-        auto tensor = py::cast<tensor::TensorPtr>(input);
-        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
-        new_tensor->set_device_address(tensor->device_address());
-        new_tensor->set_dirty(tensor->is_dirty());
-        result[i] = new_tensor;
+        input = py::getattr(input, "data");
+      }
+      auto tensor = py::cast<tensor::TensorPtr>(input);
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
+      new_tensor->set_device_address(tensor->device_address());
+      new_tensor->set_dirty(tensor->is_dirty());
+      result[i] = new_tensor;
       }
     }
     *status = PYNATIVE_SUCCESS;
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 632efa0cc1..0f28d9572f 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -120,6 +120,9 @@ class GradOperation(GradOperation_):
         """ Pynative forward run to build grad graph. """
         if self.sens_param:
             args = args[:-1]
+        for arg in args:
+            if not isinstance(arg, Tensor):
+                raise TypeError("grad inputs should be tensor in pynative mode")
         if isinstance(fn, FunctionType):
             _pynative_exec.set_grad_flag(True)
             _pynative_exec.new_graph(fn, *args)
@@ -150,9 +153,6 @@ class GradOperation(GradOperation_):
         else:
             @_wrap_func
             def after_grad(*args):
-                for arg in args:
-                    if not isinstance(arg, Tensor):
-                        raise TypeError("grad inputs should be tensor in pynative mode")
                 self._pynative_forward_run(args, fn)
                 _pynative_exec.grad(grad_, fn, weights, *args)
                 out = _pynative_exec(*args)
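
Taken together, the two hunks above make the following changes:

- In `RunOpInVM` (C++), inputs carrying `__parameter__` previously had their `data` attribute stored into `result` directly, skipping the copy branch. Now the `data` tensor falls through to the same path as ordinary tensor inputs, so every `result[i]` is a fresh `tensor::Tensor` sharing the original's data pointer, device address, and dirty flag.
- In `GradOperation` (Python), the Tensor type check moves from the `after_grad` closure into `_pynative_forward_run`, so any caller of that helper now gets the validation, and (since `args = args[:-1]` runs first when `sens_param` is set) a trailing sensitivity argument is no longer subject to the check.

The sketch below illustrates the behavior the relocated check enforces; it is illustrative, not part of the patch. It assumes a MindSpore build that includes this change, `net` is an arbitrary Cell, and the `GradOperation` constructor signature shown here varies across MindSpore versions:

```python
import numpy as np
from mindspore import Tensor, context, nn
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

net = nn.Dense(3, 2)                      # any Cell would do here
grad_all = C.GradOperation(get_all=True)  # keyword-only form, MindSpore >= 1.0

x = Tensor(np.ones((1, 3), np.float32))
grads = grad_all(net)(x)                  # accepted: every input is a Tensor

try:
    grad_all(net)(np.ones((1, 3), np.float32))  # raw ndarray, not a Tensor
except TypeError as err:
    print(err)  # grad inputs should be tensor in pynative mode
```

The error message matches the string now raised from `_pynative_forward_run`; under the old placement the same `TypeError` was raised from inside `after_grad`.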