From 1a22007f1f5634c0eb67703bfa01babafd011dd2 Mon Sep 17 00:00:00 2001 From: yefeng Date: Tue, 2 Feb 2021 09:46:50 +0800 Subject: [PATCH] pow-1 --- .../parser/tf/tf_arithmetic_self_parser.cc | 3 + .../converter/parser/tf/tf_one_hot_parser.cc | 67 +++++++++++++++++++ .../converter/parser/tf/tf_one_hot_parser.h | 36 ++++++++++ .../converter/parser/tf/tf_resize_parser.cc | 3 + 4 files changed, 109 insertions(+) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.h diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc index 76d5812d82..b6939409ec 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc @@ -61,6 +61,8 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, status = CreateOperator(primitive, schema::PrimitiveType_Log); } else if (tf_op.op() == "Sqrt") { status = CreateOperator(primitive, schema::PrimitiveType_Sqrt); + } else if (tf_op.op() == "Pow") { + status = CreateOperator(primitive, schema::PrimitiveType_Power); } if (status != RET_OK) { return status; @@ -84,5 +86,6 @@ TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.cc new file mode 100644 index 0000000000..07144c651f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_one_hot_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFOneHotParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF OneHotParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { + MS_LOG(ERROR) << "The axis attr should be specified"; + return RET_ERROR; + } + attr->axis = static_cast(attr_value.i()); + primitive->value.type = schema::PrimitiveType_OneHot; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfOneHotParser("OneHot", new TFOneHotParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.h new file mode 100644 index 0000000000..ead8ca7cd3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_one_hot_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFOneHotParser : public TFNodeParser { + public: + TFOneHotParser() = default; + ~TFOneHotParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc index 3c2f18b744..1c8c365c88 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc @@ -56,6 +56,8 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op, attr->method = schema::ResizeMethod_LINEAR; } else if (tf_op.op() == "ResizeNearestNeighbor") { attr->method = schema::ResizeMethod_NEAREST; + } else if (tf_op.op() == "ResizeBicubic") { + attr->method = schema::ResizeMethod_CUBIC; } else { attr->method = schema::ResizeMethod_UNKNOWN; } @@ -90,5 +92,6 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op, } TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser()); TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser()); +TFNodeRegistrar g_tfResizeBicubicParser("ResizeBicubic", new TFResizeParser()); } // namespace lite } // namespace mindspore