modify tf_converter & add some tf parsers

pull/8915/head
wangzhe 4 years ago
parent 5635e40b7c
commit a6fde99750

@ -114,7 +114,7 @@ endif ()
file(GLOB PROTO_FILE ""
${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto)
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(proto_mid OBJECT ${PROTO_SRCS})

@ -28,6 +28,7 @@
#include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "parser/tf/tf_converter.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/anf_importer/import_from_protobuf.h"
#include "proto/onnx.pb.h"
@ -149,6 +150,10 @@ int RunConverter(int argc, const char **argv) {
OnnxConverter onnxConverter;
fb_graph = onnxConverter.Convert(flags.get());
} break;
case FmkType::FmkType_TF: {
TFConverter tfConverter;
fb_graph = tfConverter.Convert(flags.get());
} break;
default: {
MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " "
<< GetErrorInfo(RET_INPUT_PARAM_INVALID);

@ -126,8 +126,10 @@ int Flags::Init(int argc, const char **argv) {
this->fmk = FmkType_TFLITE;
} else if (this->fmkIn == "ONNX") {
this->fmk = FmkType_ONNX;
} else if (this->fmkIn == "TF") {
this->fmk = FmkType_TF;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX";
return RET_INPUT_PARAM_INVALID;
}

@ -44,6 +44,7 @@ class ModelParser {
return func_graph;
}
protected:
virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) = 0;

@ -34,10 +34,10 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
private:
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
private:
STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);
STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);

@ -45,12 +45,12 @@ class OnnxModelParser : public ModelParser {
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
private:
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_activation_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ActivationParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (tf_op.op() == "Relu") {
attr->type = schema::ActivationType_RELU;
} else if (tf_op.op() == "Relu6") {
attr->type = schema::ActivationType_RELU6;
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
}
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFActivationParser : public TFNodeParser {
public:
TFActivationParser() = default;
~TFActivationParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_

@ -0,0 +1,93 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_arithmetic_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ArithmeticParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
if (tf_op.op() == "Add") {
auto attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Add;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Sub") {
auto attr = std::make_unique<schema::SubT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Sub;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Mul") {
auto attr = std::make_unique<schema::MulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Mul;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Div") {
auto attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Div;
primitive->value.value = attr.release();
}
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser());
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFArithmeticParser : public TFNodeParser {
public:
TFArithmeticParser() = default;
~TFArithmeticParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_biasadd_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF BiasAddParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
attr->axis = {1};
primitive->value.type = schema::PrimitiveType_Add;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFBiasAddParser : public TFNodeParser {
public:
TFBiasAddParser() = default;
~TFBiasAddParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_

@ -0,0 +1,22 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_converter.h"
#include "tools/converter/parser/tf/tf_model_parser.h"
namespace mindspore {
namespace lite {
TFConverter::TFConverter() { modelParser = new TFModelParser(); }
} // namespace lite
} // namespace mindspore

@ -13,22 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_add_parser.h"
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#include <string>
#include <memory>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/converter.h"
namespace mindspore {
namespace lite {
STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
PrimitiveC *primitiveC, int *output_size) {
auto attr = std::make_unique<schema::PrimitiveT>();
attr->value.type = schema::PrimitiveType_Add;
primitiveC = PrimitiveC::Create(attr.release());
MS_LOG(INFO) << "primitive name" << primitiveC->type_name();
return RET_OK;
}
TFNodeRegistrar g_tfAddParser("Add", new TFAddParser());
class TFConverter : public Converter {
public:
TFConverter();
~TFConverter() = default;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_matmul_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF MatMulParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) {
attr->transposeA = attr_value.b();
}
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) {
attr->transposeB = attr_value.b();
}
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser());
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save