From 31008247034dcebe5a9297678a1654e68c34a1ab Mon Sep 17 00:00:00 2001 From: hexia Date: Wed, 29 Jul 2020 18:13:03 +0800 Subject: [PATCH] fix input --- .../session/ascend_inference_session.cc | 59 +++++++++++++------ .../session/ascend_inference_session.h | 4 +- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc index 69bb7de3cc..ee0816ee9e 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -94,25 +94,33 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect MS_EXCEPTION_IF_NULL(kernel_graph); auto kernel_graph_inputs = kernel_graph->inputs(); size_t no_weight_input = 0; + vector paras; + // find parameters of graph inputs for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { - tensor::TensorPtr tensor = nullptr; if (!kernel_graph_inputs[i]->isa()) { MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; continue; } auto parameter = kernel_graph_inputs[i]->cast(); if (!AnfAlgo::IsParameterWeight(parameter)) { - // compare input number - if (no_weight_input >= inputs.size()) { - MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size() - << "] less than that of graph."; - return false; - } - auto input = inputs[no_weight_input++]; - if (!CompareInput(input, parameter)) { - MS_LOG(ERROR) << "Please check the input information."; - return false; - } + paras.push_back(parameter); + } + } + + // check inputs + for (size_t i = 0; i < paras.size(); ++i) { + // compare input number + if (paras.size() != inputs.size()) { + MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size() + << "] but the graph input number is [" << paras.size() << "]"; + MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs); + return false; + } + auto input = inputs[no_weight_input++]; + if (!CompareInput(input, paras[i])) { + MS_LOG(ERROR) << "Please check the input information."; + MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs); + return false; } } return true; @@ -123,12 +131,6 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const MS_EXCEPTION_IF_NULL(parameter); // compare dims auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); - if (input->shape().size() != parameter_shape.size()) { - MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size() - << ", but the parameter dim is " << parameter_shape.size() - << ". parameter : " << parameter->DebugString(); - return false; - } // compare shape auto input_shape = input->shape(); @@ -153,12 +155,31 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const return true; } -std::string AscendInferenceSession::PrintInputShape(std::vector shape) const { +template +std::string AscendInferenceSession::PrintInputShape(std::vector shape) const { string res = "["; for (auto dim : shape) { res += " " + std::to_string(dim); } return res + " ]"; } + +std::string AscendInferenceSession::InputsInfo(const std::vector ¶s, + const std::vector &inputs) const { + std::string graph = "graph inputs:{ "; + for (size_t i = 0; i < paras.size(); ++i) { + graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(paras[i], 0).size()) + + ", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(paras[i], 0)) + ", data type " + + std::to_string(AnfAlgo::GetSelectKernelBuildInfo(paras[i])->GetOutputDeviceType(0)) + " }"; + } + + std::string actual = "actual inputs:{ "; + for (size_t i = 0; i < inputs.size(); ++i) { + actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " + + PrintInputShape(inputs[i]->shape()) + ", data type " + std::to_string(inputs[i]->data_type()) + " }"; + } + return graph + " " + actual; +} + } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h index 7f4f478002..976ce7b63f 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.h +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h @@ -41,7 +41,9 @@ class AscendInferenceSession : public AscendSession { GraphId CompileGraph(NotNull func_graph) override; bool CheckModelInputs(uint32_t graph_id, const std::vector &inputs) const override; bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const; - std::string PrintInputShape(std::vector shape) const; + template + std::string PrintInputShape(std::vector shape) const; + std::string InputsInfo(const std::vector ¶s, const std::vector &inputs) const; }; MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); } // namespace session