|
|
@ -115,7 +115,8 @@ STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
|
|
|
|
// parse inputs
|
|
|
|
// parse inputs
|
|
|
|
for (auto input_idx : op->inputs) {
|
|
|
|
for (int i = 0; i < static_cast<int>(op->inputs.size()); i++) {
|
|
|
|
|
|
|
|
auto input_idx = op->inputs.at(i);
|
|
|
|
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
|
|
|
|
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -127,9 +128,27 @@ STATUS TfliteModelParser::ConvertOps() {
|
|
|
|
op_inputs.emplace_back(nodes_.at(input_idx));
|
|
|
|
op_inputs.emplace_back(nodes_.at(input_idx));
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// const tensor
|
|
|
|
// const tensor
|
|
|
|
|
|
|
|
std::string tensor_name;
|
|
|
|
|
|
|
|
if (!input_tensor->name.empty()) {
|
|
|
|
|
|
|
|
tensor_name = input_tensor->name;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
tensor_name = op_name + "/input-" + std::to_string(op_inputs.size());
|
|
|
|
|
|
|
|
if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
|
|
|
|
|
|
|
|
tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
|
|
|
|
|
|
|
|
tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
|
|
|
|
|
|
|
|
tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
|
|
|
|
|
|
|
|
if (i == 1) {
|
|
|
|
|
|
|
|
tensor_name = op_name + "/weight";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (i == 2) {
|
|
|
|
|
|
|
|
tensor_name = op_name + "/bias";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
auto parameter = func_graph_->add_parameter();
|
|
|
|
auto parameter = func_graph_->add_parameter();
|
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get());
|
|
|
|
status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
|
|
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
@ -248,11 +267,12 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
|
|
|
|
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
|
|
|
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
|
|
|
parameter->set_abstract(abstract_tensor);
|
|
|
|
parameter->set_abstract(abstract_tensor);
|
|
|
|
parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter");
|
|
|
|
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
|
|
|
|
nodes_.insert(std::pair(tflite_graph_input, parameter));
|
|
|
|
nodes_.insert(std::pair(tflite_graph_input, parameter));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteModelParser::ConvertGraphOutputs() {
|
|
|
|
STATUS TfliteModelParser::ConvertGraphOutputs() {
|
|
|
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
|
|
|
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
|
|
|
if (tflite_subgraph->outputs.size() > 1) {
|
|
|
|
if (tflite_subgraph->outputs.size() > 1) {
|
|
|
@ -312,7 +332,8 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) {
|
|
|
|
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter,
|
|
|
|
|
|
|
|
const std::string &tensor_name) {
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
|
|
|
|
MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
return RET_NULL_PTR;
|
|
|
@ -329,7 +350,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
|
|
|
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
|
|
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
|
|
|
parameter->set_abstract(abstract_tensor);
|
|
|
|
parameter->set_abstract(abstract_tensor);
|
|
|
|
parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter");
|
|
|
|
parameter->set_name(tensor_name);
|
|
|
|
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
|
|
|
MS_ASSERT(param_value != nullptr);
|
|
|
|
MS_ASSERT(param_value != nullptr);
|
|
|
|