pull/10295/head
yefeng 5 years ago
parent 4ad573258c
commit 1391142ba2

@ -49,6 +49,11 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
MS_ASSERT(std::all_of(kernels.begin(), kernels.end(), [](kernel::LiteKernel *kernel) {
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().end(), [](Tensor *in_tensor) {
return in_tensor->IsConst() || in_tensor->IsGraphInput() || in_tensor->ref_count() == 0;
});
}));
std::queue<kernel::LiteKernel *> kernel_queue;
for (auto kernel : kernels) {
if (kernel->IsReady(kernel->in_tensors())) {

@ -231,7 +231,7 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputNodes(const std::
auto all_input_tensors = kernel->in_tensors();
// remove all const tensor from input tensors
for (auto iter = all_input_tensors.begin(); iter != all_input_tensors.end();) {
if ((*iter)->IsConst() || (*iter)->IsGraphInput()) {
if ((*iter)->IsConst()) {
iter = all_input_tensors.erase(iter);
} else {
iter++;

@ -178,6 +178,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
if (IsContain(model_input_indices, i)) {
dst_tensor->set_category(Tensor::GRAPH_INPUT);
}
if (src_tensor->name() != nullptr) {
dst_tensor->set_tensor_name(src_tensor->name()->str());
}
this->tensors_.emplace_back(dst_tensor);
}
return RET_OK;
@ -305,6 +308,9 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
return;
}
this->output_tensor_map_.insert(std::make_pair(std::to_string(graph_out_index), out_tensor));
if (!out_tensor->tensor_name().empty()) {
this->output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor));
}
}
}

@ -65,6 +65,10 @@ class Tensor : public mindspore::tensor::MSTensor {
virtual bool operator==(const Tensor &tensor);
void set_tensor_name(std::string name) { tensor_name_ = name; }
std::string tensor_name() const { return tensor_name_; }
TypeId data_type() const override { return data_type_; }
void set_data_type(TypeId data_type) { data_type_ = data_type; }
@ -162,6 +166,7 @@ class Tensor : public mindspore::tensor::MSTensor {
}
protected:
std::string tensor_name_;
void *data_ = nullptr;
void *device_data_ = nullptr;
TypeId data_type_;

@ -185,7 +185,6 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input-" + std::to_string(i));
nodes_.insert(std::pair(layer.top(0), parameter));
return RET_OK;
}
}

Loading…
Cancel
Save