!9612 Fix-bug-of-nullptr-in-valuenode-when-parameter-connect-with-backward-op

From: @joylvliang
Reviewed-by: @chujinjin, @jjfeing
Signed-off-by: @chujinjin
pull/9612/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit e7ea724386

@@ -296,57 +296,6 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
  return graph_info;
}
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto tensor = py::cast<tensor::TensorPtr>(input);
      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
      new_tensor->set_device_address(tensor->device_address());
      new_tensor->set_sync_status(tensor->sync_status());
      result[i] = new_tensor;
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  if (op_exec_info->op_name == "stop_gradient" && py::isinstance<tensor::Tensor>(result)) {
    py::tuple tuple_result(1);
    auto tensor = py::cast<tensor::TensorPtr>(result);
    MS_EXCEPTION_IF_NULL(tensor);
    auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
    new_tensor->set_device_address(tensor->device_address());
    new_tensor->set_sync_status(tensor->sync_status());
    tuple_result[0] = new_tensor;
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(tuple_result);
  }
  // execute op
  py::tuple tuple_result = py::make_tuple(result);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(tuple_result);
}
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
@@ -1321,6 +1270,54 @@ py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_poli
  return result;
}
py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(INFO) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto input_obj_id = GetId(input);
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() &&
          op_exec_info->op_name == "HookBackward") {
        // the input object is not an output of a forward cnode, e.g. a parameter
        result[i] = tensor;
      } else {
        // the input object is an output of a forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(INFO) << "RunOpInVM end";
    return std::move(result);
  }
  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  // execute op
  py::tuple tuple_result = py::make_tuple(result);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(INFO) << "RunOpInVM end";
  return std::move(tuple_result);
}
py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
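The behavioral change sits in the HookBackward branch of the new RunOpInVM above: each input's object id is looked up in obj_to_forward_id_, and the tensor is cloned only when it is recorded as the output of a forward cnode; an input with no record, e.g. a parameter wired straight into HookBackward, is now returned as-is, which is the case the commit title's nullptr-in-valuenode fix targets (stop_gradient is also folded into this input-copy branch instead of copying its result). A simplified standalone sketch of that decision, where Tensor, GetId and the map are stand-ins for the real MindSpore types and member:

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>

struct Tensor {};  // hypothetical stand-in for mindspore::tensor::Tensor
using TensorPtr = std::shared_ptr<Tensor>;

// Stand-in for GetId(): any stable id string for the input object works here.
std::string GetId(const TensorPtr &tensor) {
  std::ostringstream oss;
  oss << tensor.get();
  return oss.str();
}

// Stand-in for the executor's obj_to_forward_id_ member: object id -> forward op id.
std::unordered_map<std::string, std::string> obj_to_forward_id_;

// Mirrors the new branch: pass parameters through untouched, clone forward outputs.
TensorPtr PrepareHookBackwardInput(const TensorPtr &input) {
  const auto input_obj_id = GetId(input);
  if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end()) {
    // Not recorded as the output of a forward cnode (e.g. a parameter):
    // keep the original object instead of handing back a freshly cloned tensor.
    return input;
  }
  // Output of a forward cnode: clone it, as the pre-fix code did for every input.
  return std::make_shared<Tensor>(*input);
}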

@@ -53,8 +53,6 @@ struct PrimAbsInfo {
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args);
void ClearPyNativeSession();
@@ -124,6 +122,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
  py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
  void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
  py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                    PynativeStatusCode *const status);
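The header hunk above mirrors the move: the free RunOpInVM declaration is removed and an equivalent declaration is added inside PynativeExecutor next to RunOpInMs and RunOpWithBackendPolicy, which is what gives the implementation access to the executor's obj_to_forward_id_ record. A skeleton of the resulting shape, using stand-in types; the access level and the map's value type are assumptions, since neither appears in this hunk:

#include <memory>
#include <string>
#include <unordered_map>

// Stand-ins; the real py::object, OpExecInfoPtr and PynativeStatusCode are not reproduced here.
struct PyObjectLike {};
struct OpExecInfo {};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
enum PynativeStatusCode { PYNATIVE_SUCCESS, PYNATIVE_OP_NOT_IMPLEMENTED_ERR };

class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
 private:  // assumed: the surrounding access level is not visible in the diff
  PyObjectLike RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  PyObjectLike RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  // Which objects came from forward cnodes; the string -> string value type is assumed.
  std::unordered_map<std::string, std::string> obj_to_forward_id_;
};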
