diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc index f32bb15b86..2b2cda5976 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc @@ -46,6 +46,10 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, attr->type = schema::ActivationType_RELU; } else if (tf_op.op() == "Relu6") { attr->type = schema::ActivationType_RELU6; + } else if (tf_op.op() == "Sigmoid") { + attr->type = schema::ActivationType_SIGMOID; + } else if (tf_op.op() == "Tanh") { + attr->type = schema::ActivationType_TANH; } else { MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); } @@ -64,5 +68,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, } TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); +TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); +TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc index 9ff2afa129..fbcb0029e2 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -69,6 +69,54 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, } primitive->value.type = schema::PrimitiveType_Div; primitive->value.value = attr.release(); + } else if (tf_op.op() == "Maximum") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Maximum; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Minimum") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Minimum; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Greater") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Greater; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "GreaterEqual") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_GreaterEqual; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "Less") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Less; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "LessEqual") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_LessEqual; + primitive->value.value = attr.release(); } *primitiveC = PrimitiveC::Create(primitive.release()); @@ -86,8 +134,15 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, return status; } TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser()); +TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser()); TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); +TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser()); +TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser()); +TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); +TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser()); +TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser()); +TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc index b749946b1c..cd3434aec2 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc @@ -44,7 +44,7 @@ STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, attr->axis = {1}; - primitive->value.type = schema::PrimitiveType_Add; + primitive->value.type = schema::PrimitiveType_BiasAdd; primitive->value.value = attr.release(); *primitiveC = PrimitiveC::Create(primitive.release()); if (*primitiveC == nullptr) { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc new file mode 100644 index 0000000000..7f758eb414 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_cast_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFCastParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF CastParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + auto src_type = TensorFlowUtils::ParseAttrDataType(tf_op, "SrcT"); + if (src_type == kTypeUnknown) { + MS_LOG(ERROR) << "Get attr SrcT failed"; + return RET_ERROR; + } + auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT"); + if (dst_type == kTypeUnknown) { + MS_LOG(ERROR) << "Get attr DstT failed"; + return RET_ERROR; + } + attr->srcT = src_type; + attr->dstT = dst_type; + + primitive->value.type = schema::PrimitiveType_Cast; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfCastParser("Cast", new TFCastParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.h new file mode 100644 index 0000000000..5f0bca39a0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFCastParser : public TFNodeParser { + public: + TFCastParser() = default; + ~TFCastParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc new file mode 100644 index 0000000000..7145c545ec --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_concat_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ConcatParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Concat input axis failed"; + return RET_ERROR; + } + auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1)); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + attr->axis = tensor_proto.int_val(0); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) { + MS_LOG(ERROR) << "The N attr should be specified"; + return RET_ERROR; + } + attr->n = (int32_t)attr_value.i(); + + primitive->value.type = schema::PrimitiveType_Concat; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size() - 1; ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfConcatV2Parser("ConcatV2", new TFConcatParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.h new file mode 100644 index 0000000000..ea9fccc142 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFConcatParser : public TFNodeParser { + public: + TFConcatParser() = default; + ~TFConcatParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.cc new file mode 100644 index 0000000000..a775ec8786 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_conv_base_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "schema/inner/model_generated.h" +namespace mindspore { +namespace lite { +namespace { +const uint32_t STRIDE_DEFAULT_VALUE = 1; +const uint32_t DILATION_DEFAULT_VALUE = 1; +} // namespace +STATUS TFConvBaseParser::ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format, + std::vector *strides) { + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(node_def, "strides", &attr_value)) { + strides->at(0) = STRIDE_DEFAULT_VALUE; + strides->at(1) = STRIDE_DEFAULT_VALUE; + } else { + auto stride_list = attr_value.list(); + if (format == schema::Format_NHWC) { + strides->at(0) = stride_list.i(1); + strides->at(1) = stride_list.i(2); + } else { + strides->at(0) = stride_list.i(2); + strides->at(1) = stride_list.i(3); + } + } + return RET_OK; +} + +STATUS TFConvBaseParser::ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format, + std::vector *dilations) { + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(node_def, "dilations", &attr_value)) { + dilations->at(0) = DILATION_DEFAULT_VALUE; + dilations->at(1) = DILATION_DEFAULT_VALUE; + } else { + auto dilation_list = attr_value.list(); + if (format == schema::Format_NHWC) { + dilations->at(0) = dilation_list.i(1); + dilations->at(1) = dilation_list.i(2); + } else { + dilations->at(0) = dilation_list.i(2); + dilations->at(1) = dilation_list.i(3); + } + } + return RET_OK; +} + +STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format, + std::vector *kernel) { + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(node_def, "value", &attr_value)) { + MS_LOG(ERROR) << "The kernels should be specified"; + return RET_PARAM_INVALID; + } + auto shape = attr_value.tensor().tensor_shape(); + if (shape.dim().size() != 4) { + MS_LOG(ERROR) << "Dims of Kernel should be 4."; + return RET_PARAM_INVALID; + } + kernel->at(0) = shape.dim(0).size(); + kernel->at(1) = shape.dim(1).size(); + kernel->at(2) = shape.dim(2).size(); + kernel->at(3) = shape.dim(3).size(); + return RET_OK; +} + +STATUS TFConvBaseParser::ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode) { + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(node_def, "padding", &attr_value)) { + MS_LOG(ERROR) << "The attr padding should be specified"; + return RET_PARAM_INVALID; + } + if (attr_value.s() == "VALID") { + *pad_mode = schema::PadMode_VALID; + } else if (attr_value.s() == "SAME") { + *pad_mode = schema::PadMode_SAME_UPPER; + } else { + *pad_mode = schema::PadMode_NOTSET; + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.h new file mode 100644 index 0000000000..d03f4167bc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_base_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFConvBaseParser : public TFNodeParser { + public: + TFConvBaseParser() = default; + ~TFConvBaseParser() override = default; + STATUS ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector *strides); + STATUS ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format, + std::vector *dilations); + STATUS ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector *kernel); + STATUS ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc new file mode 100644 index 0000000000..688792607c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_conv_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/parser/tf/tf_util.h" + +namespace mindspore { +namespace lite { +STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ConvParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + attr->group = 1; + attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); + + std::vector dilations(2); + auto status = ParseDilations(tf_op, attr->format, &dilations); + if (status != RET_OK) { + return status; + } + attr->dilateH = dilations[0]; + attr->dilateW = dilations[1]; + + std::vector strides(2); + status = ParseStrides(tf_op, attr->format, &strides); + if (status != RET_OK) { + return status; + } + attr->strideH = strides[0]; + attr->strideW = strides[1]; + + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Conv2D input weights failed"; + return RET_ERROR; + } + auto weight_node = tf_node_map.at(tf_op.input(1)); + std::vector kernels(4); + status = ParseKernels(*weight_node, attr->format, &kernels); + if (status != RET_OK) { + return status; + } + attr->kernelH = kernels[0]; + attr->kernelW = kernels[1]; + attr->channelIn = kernels[2]; + attr->channelOut = kernels[3]; + + status = ParsePadMode(tf_op, &attr->padMode); + if (status != RET_OK) { + return status; + } + + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); // weights + return status; +} +TFNodeRegistrar g_tfConvParser("Conv2D", new TFConvParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.h new file mode 100644 index 0000000000..ffcc403272 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_conv_base_parser.h" +namespace mindspore { +namespace lite { +class TFConvParser : public TFConvBaseParser { + public: + TFConvParser() = default; + ~TFConvParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc new file mode 100644 index 0000000000..ba8288c424 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_expand_dims_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ExpandDimsParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find ExpandDims input axis failed"; + return RET_ERROR; + } + auto axis_node = tf_node_map.at(tf_op.input(1)); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + attr->dim = tensor_proto.int_val(0); + } else { + attr->dim = (reinterpret_cast(tensor_proto.tensor_content().data()))[0]; + } + + primitive->value.type = schema::PrimitiveType_ExpandDims; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfExpandDimsParser("ExpandDims", new TFExpandDimsParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.h new file mode 100644 index 0000000000..68744c40a1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFExpandDimsParser : public TFNodeParser { + public: + TFExpandDimsParser() = default; + ~TFExpandDimsParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc new file mode 100644 index 0000000000..20c1c670b7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_gather_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF GatherParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (TensorFlowUtils::FindAttrValue(tf_op, "batch_dims", &attr_value)) { + attr->batchDims = attr_value.i(); + } + + bool axis_is_set = false; + if (tf_op.input_size() == 3) { + axis_is_set = true; + if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Gather input axis failed"; + return RET_ERROR; + } + auto axis_node = tf_node_map.at(tf_op.input(2)); + if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.dtype() == tensorflow::DT_INT32) { + if (tensor_proto.int_val_size() > 0) { + attr->axis = tensor_proto.int_val(0); + } else { + attr->axis = (reinterpret_cast(tensor_proto.tensor_content().data()))[0]; + } + } else if (tensor_proto.dtype() == tensorflow::DT_INT64) { + if (tensor_proto.int64_val_size() > 0) { + attr->axis = tensor_proto.int64_val(0); + } else { + attr->axis = (reinterpret_cast(tensor_proto.tensor_content().data()))[0]; + } + } else { + MS_LOG(ERROR) << "axis must be int32 or int64"; + return RET_ERROR; + } + } + if (attr->batchDims != 0 && !axis_is_set) { + attr->axis = attr->batchDims; + } + + primitive->value.type = schema::PrimitiveType_Gather; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfGatherV2Parser("GatherV2", new TFGatherParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.h new file mode 100644 index 0000000000..03e4e31ecd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFGatherParser : public TFNodeParser { + public: + TFGatherParser() = default; + ~TFGatherParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 2205ffa954..e36fce27b3 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -27,22 +27,6 @@ namespace mindspore { namespace lite { -static const std::unordered_map TF_TYPE_MAP = { - {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, - {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, - {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, - {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, - {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, - {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}}; - -TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) { - auto iter = TF_TYPE_MAP.find(tf_data_type); - if (iter == TF_TYPE_MAP.end()) { - MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; - return kTypeUnknown; - } - return iter->second; -} AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { AnfNodePtr ret = nullptr; @@ -151,7 +135,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa tensorflow::AttrValue attr_value; TypeId type = kNumberTypeFloat32; if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { - type = GetTFDataType(attr_value.type()); + type = TensorFlowUtils::GetTFDataType(attr_value.type()); } auto type_ptr = TypeIdToType(type); @@ -413,7 +397,7 @@ STATUS TFModelParser::ConvertGraphOutputs() { auto cnode = funcGraphPtr->NewCNode(op_inputs); cnode->set_fullname_with_scope("return"); funcGraphPtr->set_return(cnode); - } else { + } else if (output_nodes.size() == 1) { auto return_prim_ptr = GetReturnPrim(); if (return_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetReturnPrim return nullptr"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc index 71c8da5a45..7af394d659 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc @@ -15,9 +15,7 @@ */ #include "tools/converter/parser/tf/tf_node_parser.h" #include -#include #include -#include "tools/converter/parser/tf/tf_node_parser_registry.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.cc new file mode 100644 index 0000000000..5a0d2d872d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_pack_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFPackParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF PackParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { + MS_LOG(ERROR) << "The axis attr should be specified"; + return RET_ERROR; + } + attr->axis = static_cast(attr_value.i()); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) { + MS_LOG(ERROR) << "The axis attr should be specified"; + return RET_ERROR; + } + attr->n = static_cast(attr_value.i()); + + primitive->value.type = schema::PrimitiveType_Stack; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfPackParser("Pack", new TFPackParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.h new file mode 100644 index 0000000000..9fa7eaf96b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_pack_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFPackParser : public TFNodeParser { + public: + TFPackParser() = default; + ~TFPackParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc new file mode 100644 index 0000000000..94021ab3ce --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_reduce_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ReduceParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + if (tf_op.op() == "Sum") { + attr->mode = schema::ReduceMode_ReduceSum; + } else if (tf_op.op() == "Max") { + attr->mode = schema::ReduceMode_ReduceMax; + } else if (tf_op.op() == "Min") { + attr->mode = schema::ReduceMode_ReduceMin; + } else if (tf_op.op() == "Mean") { + attr->mode = schema::ReduceMode_ReduceMean; + } else if (tf_op.op() == "Prod") { + attr->mode = schema::ReduceMode_ReduceProd; + } else { + MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op(); + return RET_ERROR; + } + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "keep_dims", &attr_value)) { + MS_LOG(ERROR) << "The keep_dims attr should be specified"; + return RET_ERROR; + } + if (attr_value.value_case() != tensorflow::AttrValue::kB) { + MS_LOG(ERROR) << "the keep_dims attr of reduce should be bool type"; + return RET_ERROR; + } + attr->keepDims = attr_value.b(); + + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Reduce input axis failed"; + return RET_ERROR; + } + auto axis_node = tf_node_map.at(tf_op.input(1)); + if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->axes.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->axes.push_back(data[i]); + } + } + + primitive->value.type = schema::PrimitiveType_Reduce; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfSumParser("Sum", new TFReduceParser()); +TFNodeRegistrar g_tfMaxParser("Max", new TFReduceParser()); +TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser()); +TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser()); +TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.h new file mode 100644 index 0000000000..b1914f21f7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFReduceParser : public TFNodeParser { + public: + TFReduceParser() = default; + ~TFReduceParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.cc new file mode 100644 index 0000000000..a32ff06fff --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_reshape_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ReshapeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + attr->format = schema::Format_NHWC; + // attr->shape is omitted cause input[1] provide shape info + + primitive->value.type = schema::PrimitiveType_Reshape; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfReshapeParser("Reshape", new TFReshapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.h new file mode 100644 index 0000000000..d873c54363 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reshape_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFReshapeParser : public TFNodeParser { + public: + TFReshapeParser() = default; + ~TFReshapeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_round_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_round_parser.cc new file mode 100644 index 0000000000..86a5a8368f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_round_parser.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_round_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFRoundParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF RoundParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Round; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfRoundParser("Round", new TFRoundParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_round_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_round_parser.h new file mode 100644 index 0000000000..229181aa7e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_round_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFRoundParser : public TFNodeParser { + public: + TFRoundParser() = default; + ~TFRoundParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.cc new file mode 100644 index 0000000000..9b53470872 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_shape_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFShapeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ShapeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Shape; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfShapeParser("Shape", new TFShapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.h new file mode 100644 index 0000000000..f65a9c1467 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_shape_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SHAPE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SHAPE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFShapeParser : public TFNodeParser { + public: + TFShapeParser() = default; + ~TFShapeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SHAPE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.cc new file mode 100644 index 0000000000..026b8ec520 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_squeeze_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSqueezeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SqueezeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "squeeze_dims", &attr_value)) { + MS_LOG(ERROR) << "Find Squeeze input squeeze_dims attr failed"; + return RET_ERROR; + } + auto dims = attr_value.list(); + for (int i = 0; i < dims.i_size(); ++i) { + attr->axis.push_back(dims.i(i)); + } + + primitive->value.type = schema::PrimitiveType_Squeeze; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfSqueezeParser("Squeeze", new TFSqueezeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.h new file mode 100644 index 0000000000..95a765df29 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_squeeze_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUEEZE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUEEZE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSqueezeParser : public TFNodeParser { + public: + TFSqueezeParser() = default; + ~TFSqueezeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUEEZE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc new file mode 100644 index 0000000000..b7255da91c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_stride_slice_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF StrideSliceParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "begin_mask", &attr_value)) { + MS_LOG(ERROR) << "The begin_mask attr should be specified"; + return RET_ERROR; + } + attr->beginMask = attr_value.i(); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "end_mask", &attr_value)) { + MS_LOG(ERROR) << "The end_mask attr should be specified"; + return RET_ERROR; + } + attr->endMask = attr_value.i(); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "ellipsis_mask", &attr_value)) { + MS_LOG(ERROR) << "The ellipsis_mask attr should be specified"; + return RET_ERROR; + } + attr->ellipsisMask = attr_value.i(); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "new_axis_mask", &attr_value)) { + MS_LOG(ERROR) << "The new_axis_mask attr should be specified"; + return RET_ERROR; + } + attr->newAxisMask = attr_value.i(); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "shrink_axis_mask", &attr_value)) { + MS_LOG(ERROR) << "The shrink_axis_mask attr should be specified"; + return RET_ERROR; + } + attr->shrinkAxisMask = attr_value.i(); + + // begin + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find StridedSlice input begin failed"; + return RET_ERROR; + } + auto begin_node = tf_node_map.at(tf_op.input(1)); + if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->begin.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->begin.push_back(data[i]); + } + } + + // end + if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find StridedSlice input end failed"; + return RET_ERROR; + } + auto end_node = tf_node_map.at(tf_op.input(2)); + if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->end.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->end.push_back(data[i]); + } + } + + // strides + if (tf_node_map.find(tf_op.input(3)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find StridedSlice input strides failed"; + return RET_ERROR; + } + auto stride_node = tf_node_map.at(tf_op.input(3)); + if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->stride.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->stride.push_back(data[i]); + } + } + + primitive->value.type = schema::PrimitiveType_StridedSlice; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.h new file mode 100644 index 0000000000..2cbc75ba7d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_STRIDE_SLICE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_STRIDE_SLICE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFStrideSliceParser : public TFNodeParser { + public: + TFStrideSliceParser() = default; + ~TFStrideSliceParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_STRIDE_SLICE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc new file mode 100644 index 0000000000..42e6c131c5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tile_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTileParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TileParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Tile input multiplies failed"; + return RET_ERROR; + } + auto multiplies_node = tf_node_map.at(tf_op.input(1)); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*multiplies_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->dims.push_back(i); + attr->multiples.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->dims.push_back(i); + attr->multiples.push_back(data[i]); + } + } + + primitive->value.type = schema::PrimitiveType_Tile; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfTileParser("Tile", new TFTileParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.h new file mode 100644 index 0000000000..fee9e31639 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TILE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TILE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTileParser : public TFNodeParser { + public: + TFTileParser() = default; + ~TFTileParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TILE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc new file mode 100644 index 0000000000..6698d48404 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_transpose_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TransposeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + attr->conjugate = false; + if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + MS_LOG(ERROR) << "Find Transpose input perm failed"; + return RET_ERROR; + } + auto perm_node = tf_node_map.at(tf_op.input(1)); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->perm.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->perm.push_back(data[i]); + } + } + + primitive->value.type = schema::PrimitiveType_Transpose; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfTransposeParser("Transpose", new TFTransposeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.h new file mode 100644 index 0000000000..1dd30d0532 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TRANSPOSE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TRANSPOSE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTransposeParser : public TFNodeParser { + public: + TFTransposeParser() = default; + ~TFTransposeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TRANSPOSE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index a2a4a7498a..7411e2f5d8 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -15,16 +15,33 @@ */ #include "tools/converter/parser/tf/tf_util.h" -#include -#include #include -#include "google/protobuf/io/zero_copy_stream_impl.h" +#include +#include "src/common/log_adapter.h" +#include "schema/inner/model_generated.h" namespace mindspore { namespace lite { -bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, +static const std::unordered_map TF_TYPE_MAP = { + {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, + {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, + {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, + {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, + {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, + {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, {tensorflow::DT_STRING, mindspore::kObjectTypeString}}; + +TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) { + auto iter = TF_TYPE_MAP.find(tf_data_type); + if (iter == TF_TYPE_MAP.end()) { + MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; + return kTypeUnknown; + } + return iter->second; +} + +bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &node_def, const std::string &attr_name, tensorflow::AttrValue *attr_value) { - const google::protobuf::Map &attr = nodeDef.attr(); + const google::protobuf::Map &attr = node_def.attr(); const google::protobuf::Map::const_iterator it = attr.find(attr_name); if (it != attr.end()) { *attr_value = it->second; @@ -32,5 +49,27 @@ bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const st } return false; } + +TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name) { + tensorflow::AttrValue attr_value; + if (!FindAttrValue(node_def, attr_name, &attr_value)) { + MS_LOG(ERROR) << "Find attr failed: " << attr_name; + return kTypeUnknown; + } + return GetTFDataType(attr_value.type()); +} +schema::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_def) { + tensorflow::AttrValue attr_value; + if (!FindAttrValue(node_def, "data_format", &attr_value)) { + MS_LOG(ERROR) << "Find attr data_format failed"; + return schema::Format_NUM_OF_FORMAT; + } + if (attr_value.s() == "NHWC") { + return schema::Format_NHWC; + } else if (attr_value.s() == "NCHW") { + return schema::Format_NCHW; + } + return schema::Format_NUM_OF_FORMAT; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.h b/mindspore/lite/tools/converter/parser/tf/tf_util.h index 0c60defe58..f7e93376b6 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.h @@ -21,13 +21,17 @@ #include "proto/node_def.pb.h" #include "ir/dtype/type_id.h" #include "include/errorcode.h" +#include "schema/inner/model_generated.h" namespace mindspore { namespace lite { class TensorFlowUtils { public: - static bool FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, + static TypeId GetTFDataType(const tensorflow::DataType &tf_data_type); + static bool FindAttrValue(const tensorflow::NodeDef &node_def, const std::string &attr_name, tensorflow::AttrValue *attr_value); + static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name); + static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 97cd6b2d54..0ef0543494 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -648,7 +648,7 @@ STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, *filterK = oriDims.at(lite::CKHW_K); *filterH = oriDims.at(lite::CKHW_H); *filterW = oriDims.at(lite::CKHW_W); - } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { + } else if (type == kHWCK2KCHW || type == kHWCK2CKHW || type == kHWCK2KHWC) { *filterH = oriDims.at(lite::HWCK_H); *filterW = oriDims.at(lite::HWCK_W); *filterC = oriDims.at(lite::HWCK_C); @@ -991,6 +991,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for return RET_ERROR; } break; + case schema::Format::Format_HWCK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat(tensor, kHWCK2KHWC); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat(tensor, kHWCK2KHWC); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat(tensor, kHWCK2KHWC); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kHWCK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; case schema::Format::Format_KHWC: return RET_OK; default: diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index e0c4dd760f..070a0cea31 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -100,7 +100,8 @@ enum kTransFilterType { kKHWC2KCHW, kCKHW2KCHW, kCHWK2KCHW, - kKCHW2CKHW // 20 + kKCHW2CKHW, // 20 + kHWCK2KHWC }; STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index 5bc468da1b..e9c11634cb 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -20,6 +20,7 @@ using mindspore::lite::converter::FmkType_CAFFE; using mindspore::lite::converter::FmkType_MS; using mindspore::lite::converter::FmkType_ONNX; +using mindspore::lite::converter::FmkType_TF; using mindspore::lite::converter::FmkType_TFLITE; using mindspore::schema::QuantType_AwareTraining; using mindspore::schema::QuantType_PostTraining; @@ -182,6 +183,22 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod return lite::RET_OK; } +lite::STATUS WeightFormatHardCodePass::HardCodeTF(const AnfNodePtr &conv_node, + const ParamValueLitePtr ¶m_value) const { + MS_ASSERT(conv_cnode != nullptr); + MS_ASSERT(param_value != nullptr); + auto op_type = GetCNodeType(conv_node); + + if (op_type == schema::PrimitiveType_Conv2D) { + param_value->set_format(schema::Format::Format_HWCK); + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) + << ", node: " << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + return lite::RET_OK; +} + bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { MS_ASSERT(graph != nullptr); auto node_list = TopoSort(graph->get_return()); @@ -215,6 +232,9 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { case FmkType_TFLITE: status = HardCodeTFLITE(node, param_value); break; + case FmkType_TF: + status = HardCodeTF(node, param_value); + break; case FmkType_ONNX: status = HardCodeONNX(node, param_value); break; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h index 16ab442b1f..a46f6ab0d1 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h @@ -38,6 +38,7 @@ class WeightFormatHardCodePass : public Pass { lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; + lite::STATUS HardCodeTF(const AnfNodePtr &conv_node, const ParamValueLitePtr ¶m_value) const; private: QuantType quant_type = schema::QuantType_QUANT_NONE;