@@ -149,23 +149,19 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   return op_exec_info;
 }
 
-std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) {
+std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
+                                 const std::vector<tensor::TensorPtr> &input_tensors) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   std::string graph_info;
-  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
   // get input tensor info
-  size_t input_num = op_exec_info->op_inputs.size();
-  for (size_t index = 0; index < input_num; ++index) {
-    if (py::isinstance<tensor::Tensor>(op_exec_info->op_inputs[index])) {
-      auto tensor_ptr = py::cast<tensor::TensorPtr>(op_exec_info->op_inputs[index]);
-      MS_EXCEPTION_IF_NULL(tensor_ptr);
-      (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
-    }
+  for (const auto &input_tensor : input_tensors) {
+    MS_EXCEPTION_IF_NULL(input_tensor);
+    (void)graph_info.append(input_tensor->GetShapeAndDataTypeInfo() + "_");
   }
   // get prim and abstract info
+  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
   (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
                           op_exec_info->abstract->ToString());
   MS_LOG(INFO) << "Graph info [" << graph_info << "]";
   return graph_info;
 }
@@ -337,14 +333,14 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   if (session == nullptr) {
     session = session::SessionFactory::Get().Create(device_target);
   }
   MS_EXCEPTION_IF_NULL(session);
   session->Init(ms_context->device_id());
 
-  std::string graph_info = GetSingleOpGraphInfo(op_exec_info);
   std::vector<tensor::TensorPtr> input_tensors;
   std::vector<int> tensors_mask;
   ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
+  // get graph info for checking it whether existing in the cache
+  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
   session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, &input_tensors);
   py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);