From cd3daba7cbb0f99816e74f4885f123bade4b063e Mon Sep 17 00:00:00 2001 From: lixian Date: Tue, 2 Feb 2021 14:22:04 +0800 Subject: [PATCH] syncronize issues --- mindspore/lite/src/cxx_api/lite_context.cc | 4 +- mindspore/lite/src/cxx_api/model/model.cc | 11 +-- .../lite/src/cxx_api/model/model_impl.cc | 72 ++++++++++++++----- .../lite/src/cxx_api/tensor/tensor_impl.cc | 8 +-- mindspore/lite/src/cxx_api/utils.h | 24 ++++--- 5 files changed, 80 insertions(+), 39 deletions(-) diff --git a/mindspore/lite/src/cxx_api/lite_context.cc b/mindspore/lite/src/cxx_api/lite_context.cc index ac1aa80aa9..cc5f646a01 100644 --- a/mindspore/lite/src/cxx_api/lite_context.cc +++ b/mindspore/lite/src/cxx_api/lite_context.cc @@ -282,9 +282,9 @@ void Context::SetNPUFrequency(const std::shared_ptr &context, int freq) } auto iter = context->context_.find(kNPUFrequency); if (iter != context->context_.end()) { - iter->second = true; + iter->second = freq; } else { - context->context_.emplace(kNPUFrequency, true); + context->context_.emplace(kNPUFrequency, freq); } } diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index 7d564c2410..f94aa367c6 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -48,17 +48,12 @@ Model::Model(const GraphCell &graph, const std::shared_ptr &model_conte impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); if (impl_ == nullptr || graph.GetGraph() == nullptr) { MS_LOG(ERROR) << "Invalid graph."; + } else if (model_context == nullptr) { + MS_LOG(ERROR) << "Invalid context."; } else { - if (model_context == nullptr) { - MS_LOG(INFO) << "Invalid context, use default context."; - auto context = std::shared_ptr(new (std::nothrow) Context()); - Context::SetAsDefault(context); - impl_->SetContext(context); - } else { - impl_->SetContext(model_context); - } auto new_graph_cell = std::shared_ptr(new (std::nothrow) GraphCell(graph)); if (new_graph_cell != nullptr) { + impl_->SetContext(model_context); impl_->SetGraphCell(new_graph_cell); } else { MS_LOG(ERROR) << "New graphcell failed."; diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index 989c8e3897..0aa48361a3 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -71,7 +71,8 @@ Status ModelImpl::Build() { model_context.thread_num_ = Context::GetThreadNum(context_); model_context.device_list_.clear(); if (Context::IfCPUEnabled(context_) && Context::IfGPUEnabled(context_) && Context::IfNPUEnabled(context_)) { - MS_LOG(INFO) << "CPU/GPU/NPU cannot be enabled at the same time."; + MS_LOG(ERROR) << "CPU/GPU/NPU cannot be enabled at the same time."; + return kLiteInputParamInvalid; } if (!Context::IfCPUEnabled(context_)) { MS_LOG(INFO) << "CPU is forced to be enabled."; @@ -155,6 +156,7 @@ Status ModelImpl::Predict(const std::vector &inputs, std::vectorclear(); outputs->insert(outputs->end(), res.begin(), res.end()); return kSuccess; } @@ -167,8 +169,13 @@ std::vector ModelImpl::GetInputs() { } std::vector res; auto inputs = session_->GetInputs(); - for (auto input : inputs) { - auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(input)); + if (inputs.empty()) { + MS_LOG(ERROR) << "The inputs of model is null."; + return empty; + } + res.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(inputs[i])); if (impl == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; return empty; @@ -178,7 +185,7 @@ std::vector ModelImpl::GetInputs() { MS_LOG(ERROR) << "Create tensor failed."; return empty; } - res.push_back(tensor); + res[i] = tensor; } return res; } @@ -191,9 +198,22 @@ std::vector ModelImpl::GetOutputs() { } std::vector res; auto names = session_->GetOutputTensorNames(); + if (names.empty()) { + MS_LOG(ERROR) << "The names of model is null."; + return empty; + } auto outputs = session_->GetOutputs(); - for (auto name : names) { - auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(outputs[name])); + if (outputs.empty()) { + MS_LOG(ERROR) << "The outputs of model is null."; + return empty; + } + if (names.size() != outputs.size()) { + MS_LOG(ERROR) << "The size of outputs dose not match the size of names."; + return empty; + } + res.resize(names.size()); + for (size_t i = 0; i < names.size(); i++) { + auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(outputs[names[i]])); if (impl == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; return empty; @@ -203,7 +223,7 @@ std::vector ModelImpl::GetOutputs() { MS_LOG(ERROR) << "Create tensor failed."; return empty; } - res.push_back(tensor); + res[i] = tensor; } return res; } @@ -213,26 +233,44 @@ Status ModelImpl::Resize(const std::vector &inputs, const std::vector< MS_LOG(ERROR) << "Session is null."; return kLiteNullptr; } + if (inputs.empty()) { + MS_LOG(ERROR) << "Inputs is null."; + return kLiteInputParamInvalid; + } + if (dims.empty()) { + MS_LOG(ERROR) << "Dims is null."; + return kLiteInputParamInvalid; + } if (inputs.size() != dims.size()) { - MS_LOG(ERROR) << "The size of inputs is not equal to the size of dims."; + MS_LOG(ERROR) << "The size of inputs does not match the size of dims."; + return kLiteInputParamInvalid; + } + auto model_inputs = session_->GetInputs(); + if (model_inputs.empty()) { + MS_LOG(ERROR) << "The inputs of model is null."; return kLiteParamInvalid; } + if (inputs.size() != model_inputs.size()) { + MS_LOG(ERROR) << "The size of inputs is incorrect."; + return kLiteInputParamInvalid; + } std::vector inner_input; - for (auto input : inputs) { + inner_input.resize(inputs.size()); + std::vector> truncated_shape; + truncated_shape.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) { MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null."; return kLiteInputTensorError; } - inner_input.push_back(input.impl_->lite_tensor()); - } - std::vector> truncated_shape; - for (size_t i = 0; i < inner_input.size(); i++) { - std::vector tmp = TruncateShape(dims.at(i), inner_input.at(i)->data_type(), inner_input.at(i)->Size()); - if (tmp.empty()) { - MS_LOG(ERROR) << "Input dims[" << i << "]is invalid."; + inner_input[i] = input.impl_->lite_tensor(); + std::vector shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false); + if (shape.empty() && !(dims[i].empty())) { + MS_LOG(ERROR) << "Input dims[" << i << "] is invalid."; return kLiteParamInvalid; } - truncated_shape.push_back(tmp); + truncated_shape[i] = shape; } auto ret = session_->Resize(inner_input, truncated_shape); return static_cast(ret); diff --git a/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc b/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc index 41a430bce5..a0ec1677ba 100644 --- a/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc +++ b/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc @@ -28,11 +28,11 @@ namespace mindspore { MSTensor::Impl::Impl(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) { - std::vector truncated_shape = TruncateShape(shape, static_cast(type), data_len); - if (!truncated_shape.empty()) { - lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast(type), truncated_shape, data); - } else { + std::vector truncated_shape = TruncateShape(shape, static_cast(type), data_len, true); + if (truncated_shape.empty() && !(shape.empty())) { lite_tensor_ = nullptr; + } else { + lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast(type), truncated_shape, data); } } diff --git a/mindspore/lite/src/cxx_api/utils.h b/mindspore/lite/src/cxx_api/utils.h index 03a6c5a5c5..714709ab8c 100644 --- a/mindspore/lite/src/cxx_api/utils.h +++ b/mindspore/lite/src/cxx_api/utils.h @@ -18,22 +18,30 @@ #include "src/tensor.h" namespace mindspore { -static std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len) { +static std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, + bool verify_size) { std::vector empty; + if (shape.empty()) { + return empty; + } std::vector truncated_shape; + truncated_shape.resize(shape.size()); size_t element_size = lite::DataTypeSize(type); - for (auto i : shape) { - if (i < 0 || i > INT_MAX || element_size > INT_MAX / static_cast(i)) { + for (size_t i = 0; i < shape.size(); i++) { + auto dim = shape[i]; + if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast(dim)) { MS_LOG(ERROR) << "Invalid shape."; return empty; } else { - element_size *= static_cast(i); - truncated_shape.push_back(static_cast(i)); + element_size *= static_cast(dim); + truncated_shape[i] = static_cast(dim); } } - if (element_size != data_len) { - MS_LOG(ERROR) << "Invalid data size."; - return empty; + if (verify_size) { + if (element_size != data_len) { + MS_LOG(ERROR) << "Invalid data size."; + return empty; + } } return truncated_shape; }