@@ -94,25 +94,33 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto kernel_graph_inputs = kernel_graph->inputs();
  size_t no_weight_input = 0;
  vector<ParameterPtr> paras;
  // find parameters of graph inputs
  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
    tensor::TensorPtr tensor = nullptr;
    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
      continue;
    }
    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
    if (!AnfAlgo::IsParameterWeight(parameter)) {
      // compare input number
      if (no_weight_input >= inputs.size()) {
        MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
                      << "] less than that of graph.";
        return false;
      }
      auto input = inputs[no_weight_input++];
      if (!CompareInput(input, parameter)) {
        MS_LOG(ERROR) << "Please check the input information.";
        return false;
      }
      paras.push_back(parameter);
    }
  }

  // check inputs
  for (size_t i = 0; i < paras.size(); ++i) {
    // compare input number
    if (paras.size() != inputs.size()) {
      MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
                    << "] but the graph input number is [" << paras.size() << "]";
      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
      return false;
    }
    auto input = inputs[no_weight_input++];
    if (!CompareInput(input, paras[i])) {
      MS_LOG(ERROR) << "Please check the input information.";
      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
      return false;
    }
  }
  return true;
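Note: the reworked `CheckModelInputs` above first gathers the graph's non-weight parameters into `paras`, then verifies that the caller supplied the same number of inputs and compares them pairwise, logging `InputsInfo` on mismatch. Below is a minimal standalone sketch of that count-then-elementwise pattern; the `Shape`/`CheckShapes` names and the shape-only comparison are illustrative assumptions rather than the MindSpore API, and the count check is hoisted out of the loop for brevity.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for one graph parameter / one user input: just a shape.
using Shape = std::vector<int64_t>;

// Count check first, then per-element comparison, mirroring the new flow.
bool CheckShapes(const std::vector<Shape> &graph_inputs, const std::vector<Shape> &actual_inputs) {
  if (graph_inputs.size() != actual_inputs.size()) {
    std::cerr << "Input number is inconsistent: actual " << actual_inputs.size() << ", graph "
              << graph_inputs.size() << std::endl;
    return false;
  }
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    if (graph_inputs[i] != actual_inputs[i]) {
      std::cerr << "Input " << i << " mismatch" << std::endl;
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<Shape> graph = {{1, 3, 224, 224}, {1, 1000}};
  std::vector<Shape> actual = {{1, 3, 224, 224}, {1, 1000}};
  std::cout << std::boolalpha << CheckShapes(graph, actual) << std::endl;  // true
  return 0;
}
```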
@@ -123,12 +131,6 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
  MS_EXCEPTION_IF_NULL(parameter);
  // compare dims
  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
  if (input->shape().size() != parameter_shape.size()) {
    MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
                  << ", but the parameter dim is " << parameter_shape.size()
                  << ". parameter : " << parameter->DebugString();
    return false;
  }

  // compare shape
  auto input_shape = input->shape();
@@ -153,12 +155,31 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
  return true;
}

std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) const {
template <typename T>
std::string AscendInferenceSession::PrintInputShape(std::vector<T> shape) const {
  string res = "[";
  for (auto dim : shape) {
    res += " " + std::to_string(dim);
  }
  return res + " ]";
}

std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras,
                                               const std::vector<tensor::TensorPtr> &inputs) const {
  std::string graph = "graph inputs:{ ";
  for (size_t i = 0; i < paras.size(); ++i) {
    graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(paras[i], 0).size()) +
             ", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(paras[i], 0)) + ", data type " +
             std::to_string(AnfAlgo::GetSelectKernelBuildInfo(paras[i])->GetOutputDeviceType(0)) + " }";
  }

  std::string actual = "actual inputs:{ ";
  for (size_t i = 0; i < inputs.size(); ++i) {
    actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
              PrintInputShape(inputs[i]->shape()) + ", data type " + std::to_string(inputs[i]->data_type()) + " }";
  }
  return graph + " " + actual;
}

} // namespace session
} // namespace mindspore
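Note: a likely reason `PrintInputShape` gained a template parameter in the last hunk is that it is now called both with the device shape returned by `AnfAlgo::GetOutputDeviceShape` and with a tensor's own `shape()`, and those vectors need not share an element type. A minimal standalone reproduction of that pattern (the element types used here are assumptions for illustration, not taken from the MindSpore headers):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Same pattern as the templated PrintInputShape: accept any integral element type.
template <typename T>
std::string PrintShape(const std::vector<T> &shape) {
  std::string res = "[";
  for (auto dim : shape) {
    res += " " + std::to_string(dim);
  }
  return res + " ]";
}

int main() {
  std::vector<size_t> device_shape = {1, 3, 224, 224};   // e.g. a parameter's device shape
  std::vector<int64_t> tensor_shape = {1, 3, 224, 224};  // e.g. a tensor's own shape
  std::cout << PrintShape(device_shape) << std::endl;    // prints: [ 1 3 224 224 ]
  std::cout << PrintShape(tensor_shape) << std::endl;
  return 0;
}
```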