From a9752ea5e1301597fcd5508cdec1b559648b12f8 Mon Sep 17 00:00:00 2001 From: simson Date: Tue, 9 Mar 2021 17:25:02 +0800 Subject: [PATCH] add valuenode info to graph_info --- mindspore/ccsrc/pipeline/pynative/base.h | 2 +- .../pipeline/pynative/pynative_execute.cc | 34 +++++++++++++------ .../pipeline/pynative/pynative_execute.h | 4 +-- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index b87b169796..ad07595887 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -57,7 +57,7 @@ struct OpExecInfo { py::list op_inputs; py::dict op_attrs; - std::vector inputs_mask; + std::vector inputs_mask; bool is_dynamic_shape = false; std::string next_op_name = ""; bool is_mixed_precision_cast = false; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 24641e2cd9..bb2aa68cde 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -266,17 +266,26 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; // get input tensor info - for (const auto &tensor : input_tensors) { - MS_EXCEPTION_IF_NULL(tensor); - auto tensor_shape = tensor->shape(); + for (size_t index = 0; index < input_tensors.size(); ++index) { + MS_EXCEPTION_IF_NULL(input_tensors[index]); + auto tensor_shape = input_tensors[index]->shape(); (void)std::for_each(tensor_shape.begin(), tensor_shape.end(), [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); }); - (void)graph_info.append(std::to_string(tensor->data_type()) + "_"); - if (tensor->device_address() != nullptr) { - (void)graph_info.append( - std::to_string(std::dynamic_pointer_cast(tensor->device_address())->type_id()) + "_"); - (void)graph_info.append(std::dynamic_pointer_cast(tensor->device_address())->format() + + (void)graph_info.append(std::to_string(input_tensors[index]->data_type()) + "_"); + auto tensor_addr = input_tensors[index]->device_address(); + if (tensor_addr != nullptr) { + (void)graph_info.append(std::to_string(std::dynamic_pointer_cast(tensor_addr)->type_id()) + "_"); + (void)graph_info.append(std::dynamic_pointer_cast(tensor_addr)->format() + "_"); + } + if (static_cast(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) { + if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) { + (void)graph_info.append(std::to_string(*reinterpret_cast(input_tensors[index]->data_c())) + "_"); + } else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) { + (void)graph_info.append(std::to_string(*reinterpret_cast(input_tensors[index]->data_c())) + "_"); + } else { + MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!"; + } } } // get prim and abstract info @@ -387,8 +396,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr } else if (py::isinstance(input_object)) { double input_value = py::cast(input_object); tensor_ptr = std::make_shared(input_value, kFloat32); + *tensor_mask = kValueNodeTensorMask; } else if (py::isinstance(input_object)) { tensor_ptr = std::make_shared(py::cast(input_object), kInt64); + *tensor_mask = kValueNodeTensorMask; } else if (py::isinstance(input_object)) { tensor_ptr = TensorPy::MakeTensor(py::cast(input_object), nullptr); } else if (py::isinstance(input_object)) { @@ -452,6 +463,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector int64_t tensor_mask = static_cast(op_run_info->inputs_mask[index]); ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); // mark tensors, data : 0, weight : 1, valuenode: 2 + op_run_info->inputs_mask[index] = tensor_mask; std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); } @@ -602,7 +614,7 @@ py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { } // make cnode for building grad graph if grad flag is set. abstract::AbstractBasePtrList args_spec_list; - std::vector op_masks; + std::vector op_masks; auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list); op_exec_info->inputs_mask = op_masks; // get output abstract info @@ -677,7 +689,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { return op_exec_info; } -void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, +void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, std::vector *inputs, abstract::AbstractBasePtrList *args_spec_list) { auto prim = op_exec_info->py_primitive; for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) { @@ -715,7 +727,7 @@ void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vecto } } -AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, +AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list) { MS_EXCEPTION_IF_NULL(op_masks); MS_EXCEPTION_IF_NULL(args_spec_list); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index de3e4dcf91..8df1a77224 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -208,9 +208,9 @@ class PynativeExecutor : public std::enable_shared_from_this { PynativeStatusCode *const status); AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, std::vector *inputs, + void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, std::vector *inputs, abstract::AbstractBasePtrList *args_spec_list); - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list); abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);