diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 25c4971359..4afe5d22db 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -114,7 +114,7 @@ endif () file(GLOB PROTO_FILE "" ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto - ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto + ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) add_library(proto_mid OBJECT ${PROTO_SRCS}) diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 13e6b1d1e1..12a8a9269e 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -28,6 +28,7 @@ #include "parser/caffe/caffe_converter.h" #include "parser/tflite/tflite_converter.h" #include "parser/onnx/onnx_converter.h" +#include "parser/tf/tf_converter.h" #include "tools/anf_exporter/anf_exporter.h" #include "tools/anf_importer/import_from_protobuf.h" #include "proto/onnx.pb.h" @@ -149,6 +150,10 @@ int RunConverter(int argc, const char **argv) { OnnxConverter onnxConverter; fb_graph = onnxConverter.Convert(flags.get()); } break; + case FmkType::FmkType_TF: { + TFConverter tfConverter; + fb_graph = tfConverter.Convert(flags.get()); + } break; default: { MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " << GetErrorInfo(RET_INPUT_PARAM_INVALID); diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 7cd1543f52..40ccc61209 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -126,8 +126,10 @@ int Flags::Init(int argc, const char **argv) { this->fmk = FmkType_TFLITE; } else if (this->fmkIn == "ONNX") { this->fmk = FmkType_ONNX; + } else if (this->fmkIn == "TF") { + this->fmk = FmkType_TF; } else { - std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX"; + std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX"; return RET_INPUT_PARAM_INVALID; } diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 35592c450e..3c223abe7f 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -44,6 +44,7 @@ class ModelParser { return func_graph; } + protected: virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type = QuantType_QUANT_NONE) = 0; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 7c12315d05..c5c8c5571a 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -34,10 +34,10 @@ class CaffeModelParser : public ModelParser { virtual ~CaffeModelParser(); + private: schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type = QuantType_QUANT_NONE) override; - private: STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index df68a25c9e..265d434587 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -45,12 +45,12 @@ class OnnxModelParser : public ModelParser { int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, const QuantType &quantType); - schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type = QuantType_QUANT_NONE) override; - static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); private: + schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type = QuantType_QUANT_NONE) override; + std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); diff --git a/mindspore/lite/tools/converter/parser/tf/attr_value.proto b/mindspore/lite/tools/converter/parser/tf/proto/attr_value.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/attr_value.proto rename to mindspore/lite/tools/converter/parser/tf/proto/attr_value.proto diff --git a/mindspore/lite/tools/converter/parser/tf/function.proto b/mindspore/lite/tools/converter/parser/tf/proto/function.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/function.proto rename to mindspore/lite/tools/converter/parser/tf/proto/function.proto diff --git a/mindspore/lite/tools/converter/parser/tf/graph.proto b/mindspore/lite/tools/converter/parser/tf/proto/graph.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/graph.proto rename to mindspore/lite/tools/converter/parser/tf/proto/graph.proto diff --git a/mindspore/lite/tools/converter/parser/tf/node_def.proto b/mindspore/lite/tools/converter/parser/tf/proto/node_def.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/node_def.proto rename to mindspore/lite/tools/converter/parser/tf/proto/node_def.proto diff --git a/mindspore/lite/tools/converter/parser/tf/op_def.proto b/mindspore/lite/tools/converter/parser/tf/proto/op_def.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/op_def.proto rename to mindspore/lite/tools/converter/parser/tf/proto/op_def.proto diff --git a/mindspore/lite/tools/converter/parser/tf/resource_handle.proto b/mindspore/lite/tools/converter/parser/tf/proto/resource_handle.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/resource_handle.proto rename to mindspore/lite/tools/converter/parser/tf/proto/resource_handle.proto diff --git a/mindspore/lite/tools/converter/parser/tf/tensor.proto b/mindspore/lite/tools/converter/parser/tf/proto/tensor.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/tensor.proto rename to mindspore/lite/tools/converter/parser/tf/proto/tensor.proto diff --git a/mindspore/lite/tools/converter/parser/tf/tensor_shape.proto b/mindspore/lite/tools/converter/parser/tf/proto/tensor_shape.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/tensor_shape.proto rename to mindspore/lite/tools/converter/parser/tf/proto/tensor_shape.proto diff --git a/mindspore/lite/tools/converter/parser/tf/types.proto b/mindspore/lite/tools/converter/parser/tf/proto/types.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/types.proto rename to mindspore/lite/tools/converter/parser/tf/proto/types.proto diff --git a/mindspore/lite/tools/converter/parser/tf/versions.proto b/mindspore/lite/tools/converter/parser/tf/proto/versions.proto similarity index 100% rename from mindspore/lite/tools/converter/parser/tf/versions.proto rename to mindspore/lite/tools/converter/parser/tf/proto/versions.proto diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc new file mode 100644 index 0000000000..f32bb15b86 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_activation_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ActivationParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + if (tf_op.op() == "Relu") { + attr->type = schema::ActivationType_RELU; + } else if (tf_op.op() == "Relu6") { + attr->type = schema::ActivationType_RELU6; + } else { + MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); + } + + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); +TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h new file mode 100644 index 0000000000..0c04e4744c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFActivationParser : public TFNodeParser { + public: + TFActivationParser() = default; + ~TFActivationParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc new file mode 100644 index 0000000000..9ff2afa129 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_arithmetic_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ArithmeticParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + + if (tf_op.op() == "Add") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Add; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Sub") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Sub; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Mul") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Mul; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Div") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Div; + primitive->value.value = attr.release(); + } + + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser()); +TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); +TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); +TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h new file mode 100644 index 0000000000..6b02b7e63d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFArithmeticParser : public TFNodeParser { + public: + TFArithmeticParser() = default; + ~TFArithmeticParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc new file mode 100644 index 0000000000..b749946b1c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_biasadd_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF BiasAddParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + attr->axis = {1}; + + primitive->value.type = schema::PrimitiveType_Add; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.h new file mode 100644 index 0000000000..12fa610a01 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFBiasAddParser : public TFNodeParser { + public: + TFBiasAddParser() = default; + ~TFBiasAddParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_converter.cc b/mindspore/lite/tools/converter/parser/tf/tf_converter.cc new file mode 100644 index 0000000000..13aff62310 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_converter.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_converter.h" +#include "tools/converter/parser/tf/tf_model_parser.h" +namespace mindspore { +namespace lite { +TFConverter::TFConverter() { modelParser = new TFModelParser(); } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_converter.h similarity index 55% rename from mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc rename to mindspore/lite/tools/converter/parser/tf/tf_converter.h index a04cdbec57..6e1a685c05 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_converter.h @@ -13,22 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include "tools/converter/parser/tf/tf_add_parser.h" +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ #include #include -#include "tools/converter/parser/tf/tf_node_parser_registry.h" - +#include "tools/converter/converter.h" namespace mindspore { namespace lite { -STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, - PrimitiveC *primitiveC, int *output_size) { - auto attr = std::make_unique(); - attr->value.type = schema::PrimitiveType_Add; - primitiveC = PrimitiveC::Create(attr.release()); - MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); - return RET_OK; -} -TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); +class TFConverter : public Converter { + public: + TFConverter(); + + ~TFConverter() = default; +}; } // namespace lite } // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc new file mode 100644 index 0000000000..a0660f7301 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_matmul_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF MatMulParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) { + attr->transposeA = attr_value.b(); + } + if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) { + attr->transposeB = attr_value.b(); + } + + primitive->value.type = schema::PrimitiveType_MatMul; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.h new file mode 100644 index 0000000000..8335b96fa7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFMatMulParser : public TFNodeParser { + public: + TFMatMulParser() = default; + ~TFMatMulParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index f41474f09c..447d86db83 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -16,36 +16,236 @@ */ #include "tools/converter/parser/tf/tf_model_parser.h" -#include -#include +#include +#include +#include "src/common/utils.h" #include "src/common/log_adapter.h" -#include "tools/converter/parser/tf/tf_util.h" #include "tools/common/graph_util.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "src/param_value_lite.h" +#include "tools/common/protobuf_utils.h" namespace mindspore { namespace lite { +static const std::unordered_map TF_TYPE_MAP = { + {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, + {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, + {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, + {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, + {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, + {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}}; + +TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) { + auto iter = TF_TYPE_MAP.find(tf_data_type); + if (iter == TF_TYPE_MAP.end()) { + MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; + return kTypeUnknown; + } + return iter->second; +} + +AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { + AnfNodePtr ret = nullptr; + if (anf_node_map.find(name) != anf_node_map.end()) { + ret = anf_node_map[name]; + } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { + ret = anf_node_map[name + ":0"]; + } + return ret; +} + +std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { + if (node.op() != "Identity" && node.op() != "StopGradient") { + return node.name(); + } + auto tmp_node = &node; + while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { + tmp_node = tf_node_map[tmp_node->input(0)]; + } + return tmp_node->name(); +} + +STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, + const ParameterPtr ¶meter, std::vector *shape_vector) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(shape_vector != nullptr); + const tensorflow::TensorProto &tensor_proto = attr_value.tensor(); + const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape(); + int shape_size = 1; + shape_vector->clear(); + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector->push_back(tensor_shape.dim(i).size()); + shape_size *= tensor_shape.dim(i).size(); + } + + int tensor_size; + auto param_value = std::make_shared(); + if (param_value == nullptr) { + MS_LOG(ERROR) << "param_value is nullptr"; + return RET_ERROR; + } + if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { + auto tensor_data = new (std::nothrow) float[shape_size]; + if (tensor_proto.float_val_size() == 1) { + float value = tensor_proto.float_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + if (tensor_proto.tensor_content().size() == shape_size * sizeof(float)) { + const auto addr = reinterpret_cast(tensor_proto.tensor_content().data()); + auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + param_value->set_tensor_addr(tensor_data); + tensor_size = shape_size * sizeof(float); + } else if (type == kNumberTypeInt32) { + auto tensor_data = new (std::nothrow) int[shape_size]; + if (tensor_proto.int_val_size() == 1) { + int value = tensor_proto.int_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + if (tensor_proto.tensor_content().size() == shape_size * sizeof(int32_t)) { + const auto addr = reinterpret_cast(tensor_proto.tensor_content().data()); + auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + param_value->set_tensor_addr(tensor_data); + tensor_size = shape_size * sizeof(int); + } else if (type == kNumberTypeBool) { + auto tensor_data = new (std::nothrow) int[shape_size]; + if (tensor_proto.bool_val_size() == 1) { + int value = tensor_proto.bool_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + param_value->set_tensor_addr(tensor_data); + tensor_size = shape_size * sizeof(int); + } else { + MS_LOG(ERROR) << "Unsupport dataType: " << type; + return RET_ERROR; + } + + std::vector param_shape(shape_vector->begin(), shape_vector->end()); + param_value->set_tensor_shape(param_shape); + param_value->set_tensor_type(type); + param_value->set_tensor_size(tensor_size); + param_value->set_format(schema::Format::Format_NHWC); + parameter->set_default_param(param_value); + parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); + return RET_OK; +} + +STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter) { + MS_ASSERT(node != nullptr); + MS_ASSERT(parameter != nullptr); + + tensorflow::AttrValue attr_value; + TypeId type = kNumberTypeFloat32; + if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { + type = GetTFDataType(attr_value.type()); + } + auto type_ptr = TypeIdToType(type); + + std::vector shape; + if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { + auto &shape_attr = attr_value.shape(); + for (int i = 0; i < shape_attr.dim_size(); ++i) { + shape.push_back(shape_attr.dim(i).size()); + } + } + std::vector shape_vector(shape.begin(), shape.end()); + + if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { + MS_LOG(INFO) << "Found value attr, means it has default value"; + auto status = ConvertConstTensor(attr_value, type, parameter, &shape_vector); + if (status != RET_OK) { + return status; + } + } else { + parameter->set_name("placeholder_" + std::to_string(anf_node_map.size())); + graph_input_names.emplace_back(parameter->name()); + } + + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "abstract_tensor is nullptr"; + return RET_ERROR; + } + parameter->set_abstract(abstract_tensor); + + anf_node_map[node.name()] = parameter; + return RET_OK; +} + +STATUS TFModelParser::ConvertGraphInputsAndConsts() { + for (auto &pair : tf_node_map) { + bool have_data_depend = false; + for (int i = 0; i < pair.second->input_size(); ++i) { + auto name = pair.second->input(i); + if (!name.empty() && name[0] != '^') { // control_depend input start with "^" + have_data_depend = true; + break; + } + } + if (!have_data_depend) { + auto parameter = funcGraphPtr->add_parameter(); + if (ConvertParameter(*pair.second, parameter) != RET_OK) { + MS_LOG(ERROR) << "convert Parameter Node failed"; + return RET_ERROR; + } + } + } + return RET_OK; +} + FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { - auto status = ValidateFileStr(modelFile, ".prototxt"); + auto status = ValidateFileStr(modelFile, ".pb"); if (status != RET_OK) { - MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { + tf_graph_def = std::make_unique(); + if (tf_graph_def == nullptr) { + MS_LOG(ERROR) << "tf_graph_def is nullptr"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get()); + if (status != RET_OK) { MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } funcGraphPtr = std::make_shared(); - status = ConvertGraphInputs(); + if (funcGraphPtr == nullptr) { + MS_LOG(ERROR) << "funGraphPtr is nullptr"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + + for (int i = 0; i < tf_graph_def->node_size(); i++) { + auto &node_def = tf_graph_def->node(i); + tf_node_map[node_def.name()] = &node_def; + } + + status = ConvertGraphInputsAndConsts(); if (status != RET_OK) { - MS_LOG(ERROR) << "Convert graph inputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + status = ConvertOps(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert ops failed."; @@ -61,103 +261,36 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin } return funcGraphPtr; } -STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { - tensorflow::AttrValue attr_value; - if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { - tensorflow::AttrValue data_type; - tensorflow::DataType type = tensorflow::DT_FLOAT; - // datatype - if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { - type = data_type.type(); - } - const tensorflow::TensorProto &tensorProto = attr_value.tensor(); - const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); - parameter = funcGraphPtr->add_parameter(); - std::vector shape_vector; - int shape_size = 1; - shape_vector.resize(tensorShape.dim_size()); - for (int i = 0; i < tensorShape.dim_size(); i++) { - shape_vector[i] = tensorShape.dim(i).size(); - shape_size *= shape_vector[i]; +schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { + MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; + return nullptr; +} + +STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, + const std::vector &input_names, std::vector *inputs) { + // parse inputs + for (size_t j = 0; j < input_names.size(); j++) { + std::string input_name = input_names[j]; // input may be produced by multi-outputs node + if (tf_node_map.find(input_name) != tf_node_map.end()) { + auto input_node = tf_node_map[input_name]; + input_name = GetOriginInputName(*input_node); } - // convert const to paramter - TypePtr ms_data_ype; - auto paramValue = std::make_shared(); - if (type == tensorflow::DT_FLOAT) { - ms_data_ype = kFloat32; - auto tensor_data = new (std::nothrow) float[shape_size]; - if (tensorProto.float_val_size() == 1) { - float value = tensorProto.float_val(0); - for (int i = 0; i < shape_size; i++) { - tensor_data[i] = value; - } - } - if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { - const auto addr = reinterpret_cast(tensorProto.tensor_content().data()); - auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - } - paramValue->set_tensor_addr(tensor_data); - paramValue->set_tensor_size(shape_size * sizeof(float)); - } else if (type == tensorflow::DT_INT32) { - ms_data_ype = kInt32; - auto tensor_data = new (std::nothrow) int[shape_size]; - if (tensorProto.int_val_size() == 1) { - int value = tensorProto.int_val(0); - for (int i = 0; i < shape_size; i++) { - tensor_data[i] = value; - } - } - if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { - const auto addr = reinterpret_cast(tensorProto.tensor_content().data()); - auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - } - paramValue->set_tensor_addr(tensor_data); - paramValue->set_tensor_size(shape_size * sizeof(int)); - } else if (type == tensorflow::DT_BOOL) { - ms_data_ype = kFloat32; - auto tensor_data = new (std::nothrow) int[shape_size]; - if (tensorProto.bool_val_size() == 1) { - int value = tensorProto.bool_val(0); - for (int i = 0; i < shape_size; i++) { - tensor_data[i] = value; - } - } - paramValue->set_tensor_addr(tensor_data); - paramValue->set_tensor_size(shape_size * sizeof(int)); - } else { - MS_LOG(ERROR) << "Unsupport dataType," << node->name(); + auto input = GetAnfNode(input_name); + if (input == nullptr) { + MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; return RET_ERROR; } - auto abstract_tensor = std::make_shared(ms_data_ype, shape_vector); - parameter->set_abstract(abstract_tensor); - parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); - - std::vector param_shape; - (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), - [](const int64_t &value) { return static_cast(value); }); - - MS_ASSERT(paramValue != nullptr); - paramValue->set_tensor_shape(param_shape); - paramValue->set_tensor_type(ms_data_ype->type_id()); - paramValue->set_format(schema::Format::Format_NHWC); - paramValue->set_tensor_size(shape_size * sizeof(int)); - parameter->set_default_param(paramValue); + inputs->emplace_back(input); } return RET_OK; } -STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { + +STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) { if (output_size == 1) { std::vector shape_vector; anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); - anf_node_map.insert(std::pair(op->name(), anf_node)); + anf_node_map.insert(std::pair(op.name(), anf_node)); } else { AbstractBasePtrList abstractList; for (int output_idx = 0; output_idx < output_size; output_idx++) { @@ -174,113 +307,126 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const C CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); getItemCNode->set_fullname_with_scope(output_item_name); - anf_node_map.insert(std::pair(output_item_name, getItemCNode)); + anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); } anf_node->set_abstract(std::make_shared(abstractList)); } return RET_OK; } + STATUS TFModelParser::ConvertOps() { - NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); + NoSupportOp::GetInstance()->SetFmkType("TF"); STATUS status = RET_OK; - - // redirect identity to it's input0 - ClipIdentityAndStopGradient(); int op_idx = 0; for (int i = 0; i < tf_graph_def->node_size(); i++) { - auto node_def = tf_graph_def->mutable_node(i); - tf_node_map[node_def->name()] = node_def; - auto tf_op_type = node_def->op(); - if (tf_op_type == "Placeholder" || tf_op_type == "Const") { + auto &node_def = tf_graph_def->node(i); + const auto &op_type = node_def.op(); + if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { continue; } - auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); + auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); if (node_parser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(tf_op_type); + NoSupportOp::GetInstance()->InsertOp(op_type); status = (status == RET_OK ? RET_NOT_FIND_OP : status); - MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; + MS_LOG(ERROR) << "cannot find node parser:" << op_type; + continue; + } + if (status != RET_OK) { continue; } PrimitiveC *primitiveC = nullptr; - if (status == RET_OK) { - int output_size = 1; - status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); - if (status != RET_OK) { - MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; - continue; - } - std::vector opInputs = {NewValueNode(std::shared_ptr(primitiveC))}; - // parse inputs - for (int j = 0; j < node_def->input_size(); j++) { - auto input_node = tf_node_map[node_def->input(i)]; - // last node output - if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { - opInputs.emplace_back(anf_node_map[input_node->name()]); - continue; - } - // const tensor - if (input_node->op() == "Const") { - ParameterPtr parameter; - if (ConvertConstTensor(input_node, parameter) != RET_OK) { - MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); - return RET_ERROR; - } - opInputs.emplace_back(parameter); - anf_node_map[parameter->fullname_with_scope()] = parameter; - continue; - } - MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; - return RET_ERROR; - } - auto anf_node = funcGraphPtr->NewCNode(opInputs); - anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); + int output_size; + std::vector input_names; + status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "node " << op_type << " parser failed"; + continue; + } - // parse outputs - status = ConvertOutputTensor(node_def, anf_node, output_size); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; - } + auto value_node = NewValueNode(std::shared_ptr(primitiveC)); + if (value_node == nullptr) { + MS_LOG(ERROR) << "value_node is nullptr"; + status = RET_ERROR; + continue; + } + std::vector inputs = {value_node}; + status = ConvertInputNodes(node_def, input_names, &inputs); + if (status != RET_OK) { + continue; + } + // control_depends are not processed currently + auto anf_node = funcGraphPtr->NewCNode(inputs); + anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++)); + + status = ConvertOutputTensor(node_def, anf_node, output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + continue; } - // redirect identity to it's input0 - ClipIdentityAndStopGradient(); } - return RET_OK; + return status; } -STATUS TFModelParser::ConvertGraphInputs() { - for (int i = 0; i < tf_graph_def->node_size(); i++) { - auto node_def = tf_graph_def->mutable_node(i); - tf_node_map[node_def->name()] = node_def; - if (node_def->op() == "Placeholder") { - auto parameter = funcGraphPtr->add_parameter(); - if (ConvertConstTensor(node_def, parameter) != RET_OK) { - MS_LOG(ERROR) << "convert const tensor failed"; + +STATUS TFModelParser::ConvertGraphOutputs() { + // because output of intermediate node in anf graph may also be output tensors, we search output tensors in + // tf_node_map but not anf_node_map + std::set all_node_inputs; + std::vector output_nodes; + for (auto &pair : tf_node_map) { + for (int i = 0; i < pair.second->input_size(); ++i) { + all_node_inputs.insert(pair.second->input(i)); + } + } + for (auto &pair : tf_node_map) { + auto it = all_node_inputs.find(pair.first); + if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity + auto origin_name = GetOriginInputName(*(pair.second)); + auto anf_node = GetAnfNode(origin_name); + if (anf_node == nullptr) { + MS_LOG(ERROR) << "can't find anf node"; return RET_ERROR; } - anf_node_map[node_def->name()] = parameter; - graph_input_names.emplace_back(node_def->name()); + output_nodes.push_back(anf_node); + graph_output_names.push_back(anf_node->fullname_with_scope()); } } - return RET_OK; -} -STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } -std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { - if (node.op() != "Identity" && node.op() != "StopGradient") { - return node.name(); - } - auto tmpNode = node; - while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { - tmpNode = *tf_node_map[tmpNode.input(0)]; - } - return tmpNode.name(); -} + if (output_nodes.size() > 1) { + std::vector &make_tuple_inputs = output_nodes; + auto make_tuple_prim_ptr = GetMakeTuplePrim(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; + return RET_NULL_PTR; + } + auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); + make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim); + auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); -void TFModelParser::ClipIdentityAndStopGradient() { - for (auto &pair : tf_node_map) { - pair.second = tf_node_map[GetOriginInputName(*pair.second)]; + auto return_prim_ptr = GetReturnPrim(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; + } + auto value_node = NewValueNode(return_prim_ptr); + std::vector op_inputs = {value_node, make_tuple_cnode}; + auto cnode = funcGraphPtr->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + funcGraphPtr->set_return(cnode); + } else { + auto return_prim_ptr = GetReturnPrim(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; + } + auto value_node = NewValueNode(return_prim_ptr); + std::vector op_inputs{value_node, output_nodes.front()}; + auto return_cnode = funcGraphPtr->NewCNode(op_inputs); + return_cnode->set_fullname_with_scope("return"); + funcGraphPtr->set_return(return_cnode); } + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 1d3229c19c..67b3f50618 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -31,29 +31,36 @@ namespace mindspore { namespace lite { -class TFModelParser { +class TFModelParser : public ModelParser { public: TFModelParser() = default; ~TFModelParser() = default; FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); + protected: + schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) override; + private: - STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); - STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); + AnfNodePtr GetAnfNode(const std::string &name); + std::string GetOriginInputName(const tensorflow::NodeDef &node); + STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, + std::vector *shape_vector); + STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter); + STATUS ConvertGraphInputsAndConsts(); + STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector &input_names, + std::vector *inputs); + STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size); STATUS ConvertOps(); - STATUS ConvertGraphInputs(); STATUS ConvertGraphOutputs(); - std::string GetOriginInputName(const tensorflow::NodeDef &node); - - void ClipIdentityAndStopGradient(); - FuncGraphPtr funcGraphPtr; std::unique_ptr tf_graph_def; std::map tf_node_map; std::unordered_map anf_node_map; - std::vector graph_input_names, graphOutputNames; + std::vector graph_input_names; + std::vector graph_output_names; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc similarity index 61% rename from mindspore/lite/tools/converter/parser/tf/tf_add_parser.h rename to mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc index 00ffb2989a..71c8da5a45 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc @@ -13,23 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H - -#include #include "tools/converter/parser/tf/tf_node_parser.h" +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" namespace mindspore { namespace lite { -class TFAddParser : public TFNodeParser { - public: - TFAddParser() = default; - ~TFAddParser() override = default; - - STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, - PrimitiveC *primitiveC, int *output_size) override; -}; +STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector *inputs) { + if (tf_op.input_size() <= idx) { + MS_LOG(ERROR) << "input idx is greater than op input size"; + return RET_PARAM_INVALID; + } + inputs->push_back(tf_op.input(idx)); + return RET_OK; +} } // namespace lite } // namespace mindspore -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h index 804f31e59d..d6d3abf99b 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H #include +#include #include #include #include "tools/converter/parser/tf/tf_util.h" @@ -32,12 +33,14 @@ class TFNodeParser { virtual ~TFNodeParser() = default; - virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, - PrimitiveC *primitiveC, int *output_size) { + virtual STATUS Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { return RET_OK; } + + STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector *inputs); }; } // namespace lite } // namespace mindspore - #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc new file mode 100644 index 0000000000..7f3912b87f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_split_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SplitParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) { + MS_LOG(ERROR) << "The attribute num_split should be specified"; + return RET_PARAM_INVALID; + } + attr->numberSplit = (int32_t)(attr_value.i()); + + int split_dim_index; + int input_index; + if (tf_op.op() == "Split") { + split_dim_index = 0; + input_index = 1; + } else { + split_dim_index = 2; + input_index = 0; + } + + if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Split input split_dim node failed"; + return RET_ERROR; + } + const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index)); + if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The attribute splitDim should be specified"; + return RET_PARAM_INVALID; + } + auto split_dim_tensor = attr_value.tensor(); + attr->splitDim = split_dim_tensor.int_val(0); + *output_size = attr->numberSplit; + + if (tf_op.op() == "SplitV") { + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Split input size_splits failed"; + return RET_ERROR; + } + auto size_splits_node = tf_node_map.at(tf_op.input(1)); + if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The attribute size splits should be specified"; + return RET_PARAM_INVALID; + } + auto size_splits_tensor = attr_value.tensor(); + auto size = size_splits_tensor.tensor_content().size() / sizeof(int32_t); + attr->sizeSplits.resize(size); + auto ret = memcpy_s(attr->sizeSplits.data(), size * sizeof(int32_t), size_splits_tensor.tensor_content().data(), + size * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + + primitive->value.type = schema::PrimitiveType_Split; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + auto status = AddOpInput(tf_op, input_index, inputs); + return status; +} +TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser()); +TFNodeRegistrar g_tfSplitVParser("SplitV", new TFSplitParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_split_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.h new file mode 100644 index 0000000000..3ecefb9bd9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSplitParser : public TFNodeParser { + public: + TFSplitParser() = default; + ~TFSplitParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index c7cb8fc161..a2a4a7498a 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -22,9 +22,9 @@ namespace mindspore { namespace lite { -bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, +bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, tensorflow::AttrValue *attr_value) { - const google::protobuf::Map &attr = nodeDef->attr(); + const google::protobuf::Map &attr = nodeDef.attr(); const google::protobuf::Map::const_iterator it = attr.find(attr_name); if (it != attr.end()) { *attr_value = it->second; @@ -32,24 +32,5 @@ bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::str } return false; } - -bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { - std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); - if (!fs.is_open()) { - fprintf(stderr, "open failed %s\n", filepath); - return false; - } - - google::protobuf::io::IstreamInputStream input(&fs); - google::protobuf::io::CodedInputStream codedstr(&input); - - codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); - - bool success = message->ParseFromCodedStream(&codedstr); - - fs.close(); - - return success; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.h b/mindspore/lite/tools/converter/parser/tf/tf_util.h index 21888388f7..0c60defe58 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.h @@ -26,10 +26,8 @@ namespace mindspore { namespace lite { class TensorFlowUtils { public: - static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, + static bool FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, tensorflow::AttrValue *attr_value); - - static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); }; } // namespace lite } // namespace mindspore