@@ -81,6 +81,20 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
   return func_graph_;
 }
 
+std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
+  std::string tensor_name = op_name + "/input-" + std::to_string(index);
+  if (op_type == tflite::BuiltinOperator_CONV_2D || op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
+      op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
+    if (index == 1) {
+      tensor_name = op_name + "/weight";
+    }
+    if (index == 2) {
+      tensor_name = op_name + "/bias";
+    }
+  }
+  return tensor_name;
+}
+
 STATUS TfliteModelParser::ConvertOps() {
   const auto &tflite_subgraph = tflite_model_->subgraphs.front();
   NoSupportOp::GetInstance()->SetFmkType("TFLITE");
@@ -136,18 +150,7 @@ STATUS TfliteModelParser::ConvertOps() {
       if (!input_tensor->name.empty()) {
         tensor_name = input_tensor->name;
       } else {
-        tensor_name = op_name + "/input-" + std::to_string(op_inputs.size());
-        if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
-            tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV ||
-            tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
-            tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
-          if (i == 1) {
-            tensor_name = op_name + "/weight";
-          }
-          if (i == 2) {
-            tensor_name = op_name + "/bias";
-          }
-        }
+        tensor_name = GetTensorName(i, tflite_op_type, op_name);
       }
       auto parameter = func_graph_->add_parameter();
       status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name);
@@ -155,18 +158,7 @@ STATUS TfliteModelParser::ConvertOps() {
         MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
         continue;
       }
-
-      if (tflite_op_type == tflite::BuiltinOperator_CONV_2D ||
-          tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D ||
-          tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) {
-        if (op_inputs.size() == 2) {
-          parameter->set_name(op_name + "/weight");
-        } else if (op_inputs.size() == 3) {
-          parameter->set_name(op_name + "/bias");
-        }
-      } else {
-        parameter->set_name(op_name + "/input-" + std::to_string(op_inputs.size() - 1));
-      }
+      parameter->set_name(tensor_name);
       op_inputs.emplace_back(parameter);
       nodes_.insert(std::pair(input_idx, parameter));
     }
@@ -364,7 +356,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
     MS_LOG(ERROR) << "parameter is null, get const tensor failed.";
     return RET_NULL_PTR;
   }
-  const auto &tfliteModelBuffers = tflite_model_->buffers;
+  const auto &tflite_model_buffers = tflite_model_->buffers;
   auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
   std::vector<int64_t> shape_vector;
   (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
@@ -378,7 +370,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
   param_value->set_tensor_shape(tensor->shape);
   param_value->set_tensor_type(GetTfliteDataType(tensor->type));
   param_value->set_format(schema::Format::Format_NHWC);
-  const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
+  const auto &data = tflite_model_buffers.at(tensor->buffer)->data;
   if (!data.empty()) {
     auto size = data.size();
     char *tensor_data = new (std::nothrow) char[size];
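
For reference, a minimal standalone sketch of the naming scheme that the extracted GetTensorName helper applies in the diff above: index 1 maps to "<op>/weight" and index 2 to "<op>/bias" for conv-like and fully-connected ops, otherwise "<op>/input-<index>". The local BuiltinOperator enum here is a stand-in assumption so the sketch compiles without the generated TFLite schema headers; it is not the parser's real code.

#include <iostream>
#include <string>

// Stand-in for tflite::BuiltinOperator from the generated schema (assumption for illustration only).
enum class BuiltinOperator { CONV_2D, TRANSPOSE_CONV, DEPTHWISE_CONV_2D, FULLY_CONNECTED, ADD };

// Mirrors the naming rule introduced by the diff's GetTensorName helper.
std::string GetTensorName(size_t index, BuiltinOperator op_type, const std::string &op_name) {
  std::string tensor_name = op_name + "/input-" + std::to_string(index);
  if (op_type == BuiltinOperator::CONV_2D || op_type == BuiltinOperator::TRANSPOSE_CONV ||
      op_type == BuiltinOperator::DEPTHWISE_CONV_2D || op_type == BuiltinOperator::FULLY_CONNECTED) {
    if (index == 1) {
      tensor_name = op_name + "/weight";
    }
    if (index == 2) {
      tensor_name = op_name + "/bias";
    }
  }
  return tensor_name;
}

int main() {
  std::cout << GetTensorName(1, BuiltinOperator::CONV_2D, "conv1") << "\n";  // conv1/weight
  std::cout << GetTensorName(2, BuiltinOperator::CONV_2D, "conv1") << "\n";  // conv1/bias
  std::cout << GetTensorName(0, BuiltinOperator::ADD, "add1") << "\n";       // add1/input-0
  return 0;
}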