From 88ba8ee449c8decc14149d1f4eb4fd0ca114f6ab Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Wed, 29 Jul 2020 22:10:15 +0800 Subject: [PATCH] add deconv parser --- .../parser/tflite/tflite_deconv_parser.cc | 68 +++++++++++++++++++ .../parser/tflite/tflite_deconv_parser.h | 41 +++++++++++ .../parser/tflite/tflite_model_parser.cc | 17 +++-- .../parser/tflite/tflite_model_parser.h | 3 +- .../converter/parser/tflite/tflite_util.cc | 1 + 5 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc new file mode 100644 index 0000000000..c8e970ea2a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <vector>
+#include <memory>
+#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
+
+namespace mindspore {
+namespace lite {
+STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
+                                 const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
+                                 const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
+                                 const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set,
+                                 schema::CNodeT *op,
+                                 TensorCache *tensor_cache, bool quantized_model) {
+  MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
+  std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
+  const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
+  if (tflite_attr == nullptr) {
+    MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
+    return RET_NULL_PTR;
+  }
+  attr->group = 1;
+  attr->strideW = tflite_attr->stride_w;
+  attr->strideH = tflite_attr->stride_h;
+  attr->dilateH = 1;
+  attr->dilateW = 1;
+  attr->padMode = GetPadMode(tflite_attr->padding);
+  attr->format = schema::Format_NHWC;
+  // get the conv op weight tensor
+  auto weight_index = tflite_op->inputs[1];
+  const auto &weight_tensor = tflite_tensors[weight_index];
+  std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
+
+  if (RET_OK != ParseWeight(weight_tensors, tflite_model_buffer, tensor_cache, schema::Format_KHWC)) {
+    return RET_ERROR;
+  }
+  auto weight_shape = weight_tensor->shape;
+  attr->channelIn = weight_shape[KHWC_C];
+  attr->channelOut = weight_shape[KHWC_K];
+  attr->kernelW = weight_shape[KHWC_W];
+  attr->kernelH = weight_shape[KHWC_H];
+
+  if (op != nullptr) {
+    op->primitive = std::make_unique<schema::PrimitiveT>();
+    op->primitive->value.type = schema::PrimitiveType_DeConv2D;
+    op->primitive->value.value = attr.release();
+  }
+  return RET_OK;
+}
+
+TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
+}  // namespace lite
+}  // namespace mindspore
+
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
new file mode 100644
index 0000000000..46e7e1b8b6
---
/dev/null
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PREDICT_TFLITE_DECONV_PARSER_H
+#define PREDICT_TFLITE_DECONV_PARSER_H
+
+#include <vector>
+#include <memory>
+#include "tools/converter/parser/tflite/tflite_node_parser.h"
+#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
+
+namespace mindspore {
+namespace lite {
+class TfliteDeConvParser : public TfliteNodeParser {
+ public:
+  TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {}
+
+  STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
+               const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
+               const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
+               const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
+               TensorCache *tensor_cache,
+               bool quantizedModel) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // PREDICT_TFLITE_DECONV_PARSER_H
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
index 88a08fd0dc..28cf1f1bb7 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
@@ -112,10 +112,17 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
-STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
+STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
+                                        const
std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                                         const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) {
-  for (const auto &tfliteIndex : tflite_op->inputs) {
-    const auto &tflite_tensor = tflite_subgraph->tensors[tfliteIndex];
+  auto op_type = GetTfliteNodeType(tflite_op, tflite_model);
+  std::vector<int32_t> op_inputs(tflite_op->inputs);
+  if (op_type == "DeConv2D") {
+    reverse(op_inputs.begin(), op_inputs.end());
+  }
+
+  for (const auto &tflite_index : op_inputs) {
+    const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
     auto tensor_name = tflite_tensor->name;
     auto op = tfliteOpMap[tflite_op.get()];
     unsigned int index = tensorCache->FindTensor(tensor_name);
@@ -228,8 +235,8 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
   }
 
   for (const auto &tflite_op : tflite_subgraph->operators) {
-    auto statusTmp = SetOpInputIdx(tflite_subgraph, tflite_op, &tensorCache);
-    if (statusTmp != RET_OK) {
+    auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
+    if (status_tmp != RET_OK) {
       // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str());
     }
   }
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
index 20c5a73f48..00b0520946 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
@@ -73,7 +73,8 @@ class TfliteModelParser : public ModelParser {
                        schema::CNodeT *op, TensorCache *tensorCache);
 
-  STATUS SetOpInputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
+  STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
+                       const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                        const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache);
 
   std::map<std::string, schema::CNodeT *> opMap;
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
index 9b9532c9ff..8331644104
100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
@@ -55,6 +55,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
   {tflite::BuiltinOperator_ARG_MAX, "Argmax"},
   {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"},
   {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"},
+  {tflite::BuiltinOperator_TRANSPOSE_CONV, "DeConv2D"},
 };
 
 std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) {