From a29fee5ff066a049e193db0130b83883f6d285f9 Mon Sep 17 00:00:00 2001 From: wangzhe Date: Fri, 4 Dec 2020 16:49:43 +0800 Subject: [PATCH] tensorlist parsers --- mindspore/core/ir/dtype_extends.cc | 4 + mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/primitive_c.cc | 15 ++++ .../parser/tf/tf_activation_parser.cc | 1 + .../parser/tf/tf_arithmetic_parser.cc | 18 ++++ .../converter/parser/tf/tf_assert_parser.cc | 13 +-- .../converter/parser/tf/tf_assert_parser.h | 1 - .../converter/parser/tf/tf_conv_parser.cc | 4 + .../converter/parser/tf/tf_model_parser.cc | 23 ++++- .../tf/tf_tensor_list_from_tensor_parser.cc | 87 +++++++++++++++++++ .../tf/tf_tensor_list_from_tensor_parser.h | 37 ++++++++ .../tf/tf_tensor_list_get_item_parser.cc | 76 ++++++++++++++++ .../tf/tf_tensor_list_get_item_parser.h | 38 ++++++++ .../tf/tf_tensor_list_reserve_parser.cc | 86 ++++++++++++++++++ .../parser/tf/tf_tensor_list_reserve_parser.h | 37 ++++++++ .../tf/tf_tensor_list_set_item_parser.cc | 76 ++++++++++++++++ .../tf/tf_tensor_list_set_item_parser.h | 37 ++++++++ .../parser/tf/tf_tensor_list_stack_parser.cc | 81 +++++++++++++++++ .../parser/tf/tf_tensor_list_stack_parser.h | 37 ++++++++ .../lite/tools/converter/parser/tf/tf_util.cc | 21 +++-- 20 files changed, 677 insertions(+), 16 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.h diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc index 7c69807138..d08d92a3cd 100644 --- a/mindspore/core/ir/dtype_extends.cc +++ b/mindspore/core/ir/dtype_extends.cc @@ -77,6 +77,8 @@ TypePtr TypeIdToType(TypeId id) { return kInt16; case kNumberTypeInt32: return kInt32; + case kNumberTypeInt: + return kInt32; case kNumberTypeInt64: return kInt64; case kNumberTypeUInt8: @@ -119,6 +121,8 @@ TypePtr TypeIdToType(TypeId id) { return kSlice; case kObjectTypeKeyword: return kKeyword; + case kObjectTypeTensorType: + return kTensorType; case kTypeUnknown: return kTypeNone; default: diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 64fea0ba66..250c271a46 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1194,6 +1194,7 @@ table TensorListSetItem { table TensorListReserve { elementDType : int; + shapeType : int; } table All { diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index ddf28dde64..7bae6bd8fc 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -150,6 +150,11 @@ #include "src/ops/unsorted_segment_sum.h" #include "src/ops/reciprocal.h" #include "src/ops/constant.h" +#include "src/ops/tensorlistfromtensor.h" +#include "src/ops/tensorlistgetitem.h" +#include "src/ops/tensorlistsetitem.h" +#include "src/ops/tensorlistreserve.h" +#include "src/ops/tensorliststack.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -906,6 +911,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Reciprocal(primitive); case schema::PrimitiveType_Constant: return new (std::nothrow) Constant(primitive); + case schema::PrimitiveType_TensorListFromTensor: + return new (std::nothrow) TensorListFromTensor(primitive); + case schema::PrimitiveType_TensorListGetItem: + return new (std::nothrow) TensorListGetItem(primitive); + case schema::PrimitiveType_TensorListSetItem: + return new (std::nothrow) TensorListSetItem(primitive); + case schema::PrimitiveType_TensorListReserve: + return new (std::nothrow) TensorListReserve(primitive); + case schema::PrimitiveType_TensorListStack: + return new (std::nothrow) TensorListStack(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc index 2b2cda5976..1f29d0f3da 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc @@ -52,6 +52,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, attr->type = schema::ActivationType_TANH; } else { MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); + return RET_ERROR; } primitive->value.type = schema::PrimitiveType_Activation; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc index fbcb0029e2..10d4a1e6cb 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -117,6 +117,22 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, } primitive->value.type = schema::PrimitiveType_LessEqual; primitive->value.value = attr.release(); + } else if (tf_op.op() == "Equal") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_Equal; + primitive->value.value = attr.release(); + } else if (tf_op.op() == "NotEqual") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_NotEqual; + primitive->value.value = attr.release(); } *primitiveC = PrimitiveC::Create(primitive.release()); @@ -144,5 +160,7 @@ TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser()); TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser()); TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser()); +TFNodeRegistrar g_tfEqualParser("Equal", new TFArithmeticParser()); +TFNodeRegistrar g_tfNotEqualParser("NotEqual", new TFArithmeticParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc index f9da640ff0..1cc480453c 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc @@ -41,6 +41,7 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "new attr failed"; return RET_NULL_PTR; } + tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { MS_LOG(ERROR) << "The keep_dims attr should be specified"; @@ -56,12 +57,14 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, return RET_ERROR; } - *output_size = 1; - auto status = AddOpInput(tf_op, 0, inputs); - if (status != RET_OK) { - return status; + *output_size = 0; // Assert not have output + for (int i = 0; i < tf_op.input_size(); ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } } - return status; + return RET_OK; } TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h index cf00d1f997..818cf15b5d 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h @@ -15,7 +15,6 @@ */ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ - #include #include #include diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc index 688792607c..a790cb5e93 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc @@ -45,6 +45,10 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, attr->group = 1; attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); + if (attr->format == schema::Format_NCHW) { + MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now"; + return RET_ERROR; + } std::vector dilations(2); auto status = ParseDilations(tf_op, attr->format, &dilations); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index baa688a486..7874935a1f 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -25,10 +25,17 @@ #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/optimizer/common/gllo_utils.h" namespace mindspore { namespace lite { namespace { +static const std::vector tensorListOutputOpList = { + schema::PrimitiveType_TensorListFromTensor, + schema::PrimitiveType_TensorListSetItem, + schema::PrimitiveType_TensorListReserve, +}; + // subgraph node input may be a:output:0/a:z:0 std::string GetFlattenNodeName(std::string input_name) { std::regex re("\\:+"); @@ -107,7 +114,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value } tensor_size = shape_size * sizeof(float); param_value->SetTensorData(tensor_data, tensor_size); - } else if (type == kNumberTypeInt32) { + } else if (type == kNumberTypeInt32 || type == kNumberTypeInt) { auto tensor_data = new (std::nothrow) int[shape_size]; if (tensor_proto.int_val_size() == 1) { int value = tensor_proto.int_val(0); @@ -445,9 +452,19 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C MS_ASSERT(op != nullptr); MS_ASSERT(anf_node != nullptr); MS_ASSERT(anf_graph != nullptr); - if (output_size == 1) { + if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { + MS_LOG(ERROR) << "tensorlist output op output_size !=1"; + return RET_ERROR; + } + if (output_size == 0) { + return RET_OK; + } else if (output_size == 1) { + auto type = kFloat32; std::vector shape_vector; - anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); + if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { + type = TypeIdToType(kObjectTypeTensorType); + } + anf_node->set_abstract(std::make_shared(type, shape_vector)); anf_node_map->insert(std::pair(op.name(), anf_node)); } else { AbstractBasePtrList abstractList; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc new file mode 100644 index 0000000000..07b5376073 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, + int *output_size) { + MS_LOG(INFO) << "TF TensorListFromTensorParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_from_tensor element dtype must be known type"; + return RET_ERROR; + } + attr->elementDType = type; + + if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { + MS_LOG(ERROR) << "The shape_type attr should be specified"; + return RET_ERROR; + } + type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_from_tensor shape type must be known type"; + return RET_ERROR; + } + attr->shapeType = type; + + primitive->value.type = schema::PrimitiveType_TensorListFromTensor; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfTensorListFromTensorParser("TensorListFromTensor", new TFTensorListFromTensorParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h new file mode 100644 index 0000000000..5cb732867a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTensorListFromTensorParser : public TFNodeParser { + public: + TFTensorListFromTensorParser() = default; + ~TFTensorListFromTensorParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.cc new file mode 100644 index 0000000000..6071939e85 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_list_get_item_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTensorListGetItemParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TensorListGetItemParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_get_item element_dtype must be known type"; + return RET_ERROR; + } + attr->elementDType = type; + + primitive->value.type = schema::PrimitiveType_TensorListGetItem; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < 3; ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfTensorListGetItemParser("TensorListGetItem", new TFTensorListGetItemParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.h new file mode 100644 index 0000000000..37e5076947 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_get_item_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTensorListGetItemParser : public TFNodeParser { + public: + TFTensorListGetItemParser() = default; + ~TFTensorListGetItemParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.cc new file mode 100644 index 0000000000..6b6139c54f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_list_reserve_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTensorListReserveParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TensorListReserveParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_reserve element dtype must be known type"; + return RET_ERROR; + } + attr->elementDType = type; + + if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { + MS_LOG(ERROR) << "The shape_type attr should be specified"; + return RET_ERROR; + } + type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_reserve shape_type must be known type"; + return RET_ERROR; + } + attr->shapeType = type; + + primitive->value.type = schema::PrimitiveType_TensorListReserve; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfTensorListReserveParser("TensorListReserve", new TFTensorListReserveParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.h new file mode 100644 index 0000000000..a9c81ba830 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_reserve_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTensorListReserveParser : public TFNodeParser { + public: + TFTensorListReserveParser() = default; + ~TFTensorListReserveParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.cc new file mode 100644 index 0000000000..ac86daebf4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_list_set_item_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTensorListSetItemParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TensorListSetItemParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_set_item element dtype must be known type"; + return RET_ERROR; + } + attr->elementDType = type; + + primitive->value.type = schema::PrimitiveType_TensorListSetItem; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < 3; ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfTensorListSetItemParser("TensorListSetItem", new TFTensorListSetItemParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.h new file mode 100644 index 0000000000..b7c3f19049 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_set_item_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTensorListSetItemParser : public TFNodeParser { + public: + TFTensorListSetItemParser() = default; + ~TFTensorListSetItemParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.cc new file mode 100644 index 0000000000..18af91461c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_list_stack_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTensorListStackParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TensorListStackParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); + if (type == kTypeUnknown) { + MS_LOG(ERROR) << "tensor_list_stack element_dtype must be known type"; + return RET_ERROR; + } + attr->elementDType = type; + + if (!TensorFlowUtils::FindAttrValue(tf_op, "num_elements", &attr_value)) { + MS_LOG(ERROR) << "The element_dtype attr should be specified"; + return RET_ERROR; + } + attr->numElements = attr_value.i(); + + primitive->value.type = schema::PrimitiveType_TensorListStack; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfTensorListStackParser("TensorListStack", new TFTensorListStackParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.h new file mode 100644 index 0000000000..f39777b447 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_list_stack_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTensorListStackParser : public TFNodeParser { + public: + TFTensorListStackParser() = default; + ~TFTensorListStackParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index 7411e2f5d8..622fdad363 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -23,17 +23,24 @@ namespace mindspore { namespace lite { static const std::unordered_map TF_TYPE_MAP = { - {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, - {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, - {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, - {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, - {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, - {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, {tensorflow::DT_STRING, mindspore::kObjectTypeString}}; + {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, + {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, + {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, + {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, + {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, + {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, + {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, + {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, + {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, + {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, + {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, + {tensorflow::DT_STRING, mindspore::kObjectTypeString}, + {tensorflow::DT_VARIANT, mindspore::kObjectTypeTensorType}}; TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) { auto iter = TF_TYPE_MAP.find(tf_data_type); if (iter == TF_TYPE_MAP.end()) { - MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; + MS_LOG(WARNING) << "unsupported TF data type: " << tf_data_type; return kTypeUnknown; } return iter->second;