diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py
index 494c4e2c38..a38135c295 100644
--- a/mindspore/_extends/builtin_operations.py
+++ b/mindspore/_extends/builtin_operations.py
@@ -157,8 +157,6 @@ def tuple_to_array(x):
 
 def stop_gradient(x):
     """Implement `stop_gradient`."""
-    if isinstance(x, Tensor):
-        return Tensor(x.asnumpy())
     return x
 
 
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ed28b36a88..edb6620e14 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -314,6 +314,18 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
     py::tuple err_ret(0);
     return std::move(err_ret);
   }
+  if (op_exec_info->op_name == "stop_gradient" && py::isinstance<tensor::Tensor>(result)) {
+    py::tuple tuple_result(1);
+    auto tensor = py::cast<tensor::TensorPtr>(result);
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
+    new_tensor->set_device_address(tensor->device_address());
+    new_tensor->set_sync_status(tensor->sync_status());
+    tuple_result[0] = new_tensor;
+    *status = PYNATIVE_SUCCESS;
+    MS_LOG(INFO) << "RunOpInVM end";
+    return std::move(tuple_result);
+  }
   // execute op
   py::tuple tuple_result = py::make_tuple(result);