diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index e48124ffdc..9b297eee59 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -1103,10 +1103,13 @@ schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } #endif int PrimitiveC::Type() const { - if (this->primitive_ == nullptr) { + if (this->primitive_ == nullptr && this->op_type_ == OP_TYPE_NOT_SET) { return schema::PrimitiveType_NONE; } #ifdef PRIMITIVE_WRITEABLE + if (op_type_ != OP_TYPE_NOT_SET) { + return op_type_; + } return this->primitive_->value.type; #else return this->primitive_->value_type(); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 51f96b3466..4e2c029cd4 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -24,6 +24,9 @@ #ifdef PRIMITIVE_WRITEABLE #include "ir/primitive.h" #include "schema/inner/model_generated.h" +#include "schema/inner/ops_generated.h" +#include "schema/ops_generated.h" +#include "tools/converter/ops/ops_def.h" #else #include "schema/model_generated.h" #endif @@ -34,6 +37,7 @@ namespace mindspore { namespace lite { +constexpr const int OP_TYPE_NOT_SET = -1; constexpr uint32_t kSingleNum = 1; constexpr uint32_t kDoubleNum = 2; constexpr uint32_t kMultiNum = 3; @@ -149,6 +153,7 @@ class PrimitiveC : public mindspore::Primitive { std::vector> output_quant_param_; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; bool infer_flag_ = true; + int op_type_ = OP_TYPE_NOT_SET; }; std::shared_ptr GetReturnPrim(); @@ -227,6 +232,7 @@ class PrimitiveC { char *primitive_buf_ = nullptr; bool infer_flag_ = true; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; + int op_type_ = OP_TYPE_NOT_SET; }; using PrimitiveCPtr = std::shared_ptr; typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 1b592aaae9..8e118b2177 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -2,6 +2,7 @@ set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..) set(TEST_DIR ${TOP_DIR}/mindspore/lite/test) set(LITE_DIR ${TOP_DIR}/mindspore/lite) set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc) +set(CONVERTER_DIR ${TOP_DIR}/mindspore/lite/tools/converter) include_directories(${TOP_DIR}) include_directories(${TEST_DIR}) include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/gtest.cmake) @@ -16,7 +17,7 @@ set(CCSRC_SRC ${CCSRC_DIR}/backend/optimizer/common/visit.cc ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc ) -else(ENABLE_CONVERTER) +else() set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc) add_compile_definitions(USE_ANDROID_LOG) endif() @@ -38,10 +39,10 @@ file(GLOB KERNEL_OP_TRAIN_SRC ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc ) -if (SUPPORT_TRAIN) +if(SUPPORT_TRAIN) list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) endif() -if (PLATFORM_ARM64) +if(PLATFORM_ARM64) # assembly file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/assembly/arm64/*.s ${LITE_DIR}/nnacl/assembly/arm64/*.S) @@ -53,7 +54,7 @@ if (PLATFORM_ARM64) ) endif() -if (PLATFORM_ARM32) +if(PLATFORM_ARM32) # assembly file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/assembly/arm32/*.S @@ -65,7 +66,7 @@ if (PLATFORM_ARM32) ) endif() -if ("${X86_64_SIMD}" STREQUAL "sse") +if("${X86_64_SIMD}" STREQUAL "sse") file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c) set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_OP_SRC @@ -74,12 +75,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse") ) endif() -if ("${X86_64_SIMD}" STREQUAL "avx") +if("${X86_64_SIMD}" STREQUAL "avx") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2") file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c - ${LITE_DIR}/nnacl/x86_64_avx/*.c - ${LITE_DIR}/nnacl/assembly/avx/*.S) + ${LITE_DIR}/nnacl/x86_64_avx/*.c + ${LITE_DIR}/nnacl/assembly/avx/*.S) set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_OP_SRC ${KERNEL_OP_SRC} @@ -88,7 +89,7 @@ if ("${X86_64_SIMD}" STREQUAL "avx") endif() ### gpu kernel -if (SUPPORT_GPU) +if(SUPPORT_GPU) file(GLOB GPU_KERNEL_OP_SRC ${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc ) @@ -102,14 +103,18 @@ if (SUPPORT_GPU) ) endif() -if (PLATFORM_ARM32 OR PLATFORM_ARM64) - if (ENABLE_CONVERTER) +if(PLATFORM_ARM32 OR PLATFORM_ARM64) + if(ENABLE_CONVERTER) set(BUILD_MINDDATA "off") endif() endif() ### runtime framework add_definitions(-DENABLE_V0) file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) +if(ENABLE_CONVERTER) + file(GLOB_RECURSE CONVERTER_OPS_SRC ${CONVERTER_DIR}/ops/*.cc) + set(OPS_SRC ${OPS_SRC} ${CONVERTER_OPS_SRC}) +endif() set(TEST_LITE_SRC ${TEST_LITE_SRC} ${CCSRC_SRC} @@ -144,7 +149,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/errorcode.cc ) ### gpu runtime -if (SUPPORT_GPU) +if(SUPPORT_GPU) include_directories(${TOP_DIR}/third_party/OpenCL-Headers) include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) set(OPENCL_RUNTIME_SRC @@ -210,13 +215,14 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc ${LITE_DIR}/tools/optimizer/graph/while_pass.cc ${LITE_DIR}/tools/optimizer/graph/if_pass.cc + ${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc + ${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc ) endif() ### train -if (SUPPORT_TRAIN) +if(SUPPORT_TRAIN) set(TEST_LITE_SRC ${TEST_LITE_SRC} - # ${LITE_DIR}/src/train/ops/train_ops.cc ${LITE_DIR}/src/train/train_populate_parameter.cc ${LITE_DIR}/src/train/train_session.cc ${LITE_DIR}/src/train/train_model.cc @@ -251,7 +257,7 @@ set(TEST_SRC ${TEST_DIR}/ut/src/scheduler_test.cc ) -if (ENABLE_CONVERTER) +if(ENABLE_CONVERTER) set(TEST_SRC ${TEST_SRC} ${TEST_DIR}/st/converter_test.cc @@ -265,7 +271,7 @@ if (ENABLE_CONVERTER) ) endif() -if (SUPPORT_TRAIN) +if(SUPPORT_TRAIN) set(TEST_SRC ${TEST_SRC} ${TEST_CASE_KERNEL_TRAIN_SRC} @@ -278,7 +284,7 @@ else() ) endif() -if (SUPPORT_GPU) +if(SUPPORT_GPU) file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC ${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc ) @@ -288,7 +294,7 @@ if (SUPPORT_GPU) ) endif() -if (ENABLE_FP16) +if(ENABLE_FP16) file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc ) @@ -296,24 +302,24 @@ if (ENABLE_FP16) ${TEST_SRC} ${TEST_CASE_KERNEL_FP16_SRC} ) -endif () +endif() add_executable(lite-test ${TEST_SRC}) add_dependencies(lite-test fbs_src) target_link_libraries(lite-test dl mindspore::gtest) -if (PLATFORM_ARM64 AND ENABLE_FP16) +if(PLATFORM_ARM64 AND ENABLE_FP16) target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) endif() -if (PLATFORM_ARM) +if(PLATFORM_ARM) target_link_libraries(lite-test log) endif() -if (SUPPORT_NPU) +if(SUPPORT_NPU) include_directories(${DDK_PATH}) target_link_libraries(lite-test npu_kernel_mid) -endif () -if (ENABLE_CONVERTER) +endif() +if(ENABLE_CONVERTER) add_dependencies(lite-test fbs_inner_src) target_link_libraries(lite-test anf_importer_mid diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index aa3d61c558..262b581354 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -12,6 +12,8 @@ include(${TOP_DIR}/cmake/external_libs/glog.cmake) file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc) +file(GLOB CONVERTER_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc) + file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc @@ -65,6 +67,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/while_pass.cc ../optimizer/graph/if_pass.cc ../optimizer/graph/mindir_inputs_adjust_pass.cc + ../optimizer/graph/functionalize_control_op_pass.cc + ../optimizer/graph/functionalize_while.cc ) add_subdirectory(../anf_importer anf_importer) @@ -97,12 +101,12 @@ set(LITE_SRC ${SRC_DIR}/errorcode.cc ${SRC_DIR}/dequant.cc ) -if (SUPPORT_TRAIN) +if(SUPPORT_TRAIN) set(LITE_SRC ${LITE_SRC} ) -endif () +endif() set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm) file(GLOB KERNEL_SRC ${ARM_DIR}/base/*.cc @@ -114,13 +118,13 @@ file(GLOB KERNEL_SRC ${ARM_DIR}/int8/*.cc ) -if (PLATFORM_ARM64) +if(PLATFORM_ARM64) # assembly file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.s ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.S) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) -endif () +endif() file(GLOB PROTO_FILE "" ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto @@ -133,11 +137,13 @@ add_library(proto_mid OBJECT ${PROTO_SRCS}) set(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/schema.fbs ) -ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src ${CMAKE_BINARY_DIR}/schema "inner") +ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src + ${CMAKE_BINARY_DIR}/schema "inner") set_property(SOURCE ${CONVERTER_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) set_property(SOURCE ${OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +set_property(SOURCE ${CONVERTER_OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) set_property(SOURCE ${KERNEL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_executable(converter_lite @@ -145,6 +151,7 @@ add_executable(converter_lite ${CCSRC_SRC} ${CONVERTER_SRC} ${OPS_SRC} + ${CONVERTER_OPS_SRC} ${KERNEL_SRC} ${LITE_SRC} ) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 2b3296fa16..7e4fe5d11e 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -48,6 +48,7 @@ #include "tools/optimizer/graph/slice_prepose_pass.h" #include "tools/optimizer/graph/while_pass.h" #include "tools/optimizer/graph/if_pass.h" +#include "tools/optimizer/graph/functionalize_control_op_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -100,6 +101,15 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap } } + if (config->fmk == lite::converter::FmkType_TF) { + auto functionalize_control_op_pass = std::make_shared(); + if (!functionalize_control_op_pass->Run(old_graph)) { + MS_LOG(ERROR) << "functionalize control op pass failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + } + if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF || config->fmk == lite::converter::FmkType_ONNX) { graph_pm->AddPass(std::make_shared()); @@ -145,7 +155,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap if (config->fmk == lite::converter::FmkType_MS) { auto remove_unused_cast_pass = std::make_shared(); if (remove_unused_cast_pass == nullptr) { - MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified"; + MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified"; return nullptr; } remove_unused_cast_pass->SetFmkType(config->fmk); @@ -154,7 +164,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap if (config->fmk == lite::converter::FmkType_ONNX) { auto remove_unused_transpose_pass = std::make_shared(); if (remove_unused_transpose_pass == nullptr) { - MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified"; + MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified"; return nullptr; } remove_unused_transpose_pass->SetFmkType(config->fmk); diff --git a/mindspore/lite/tools/converter/ops/enter.cc b/mindspore/lite/tools/converter/ops/enter.cc new file mode 100644 index 0000000000..3b74170042 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/enter.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/ops/enter.h" +#include "src/tensorlist.h" + +namespace mindspore { +namespace lite { + +int Enter::InferShape(std::vector inputs_, std::vector outputs_) { + if (!infer_flag()) { + return RET_INFER_INVALID; + } + for (size_t i = 0; i < inputs_.size(); i++) { + auto *input = inputs_[i]; + auto *output = outputs_[i]; + if (input == nullptr) { + MS_LOG(ERROR) << "input tensor is nullptr"; + return RET_ERROR; + } + if (output == nullptr) { + MS_LOG(ERROR) << "output tensor is nullptr"; + return RET_ERROR; + } + output->set_data_type(input->data_type()); + output->set_shape(input->shape()); + output->set_format(input->format()); + auto data_type = input->data_type(); + if (data_type != kObjectTypeTensorType) { + continue; + } else { + auto input_tensorlist = reinterpret_cast(input); + auto output_tensorlist = reinterpret_cast(output); + output_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + } + } + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/ops/enter.h b/mindspore/lite/tools/converter/ops/enter.h new file mode 100644 index 0000000000..21f7baa7f2 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/enter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { + +class Enter : public PrimitiveC { + public: + Enter() { op_type_ = ConverterPrimitiveType_Enter; } + ~Enter() = default; + MS_DECLARE_PARENT(Enter, PrimitiveC); + explicit Enter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ diff --git a/mindspore/lite/tools/converter/ops/exit.cc b/mindspore/lite/tools/converter/ops/exit.cc new file mode 100644 index 0000000000..7b6937390f --- /dev/null +++ b/mindspore/lite/tools/converter/ops/exit.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/ops/exit.h" +#include "src/tensorlist.h" + +namespace mindspore { +namespace lite { + +int Exit::InferShape(std::vector inputs_, std::vector outputs_) { + if (!infer_flag()) { + return RET_INFER_INVALID; + } + for (size_t i = 0; i < inputs_.size(); i++) { + auto *input = inputs_[i]; + auto *output = outputs_[i]; + if (input == nullptr) { + MS_LOG(ERROR) << "input tensor is nullptr"; + return RET_ERROR; + } + if (output == nullptr) { + MS_LOG(ERROR) << "output tensor is nullptr"; + return RET_ERROR; + } + output->set_data_type(input->data_type()); + output->set_shape(input->shape()); + output->set_format(input->format()); + auto data_type = input->data_type(); + if (data_type != kObjectTypeTensorType) { + continue; + } else { + auto input_tensorlist = reinterpret_cast(input); + auto output_tensorlist = reinterpret_cast(output); + output_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + } + } + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/ops/exit.h b/mindspore/lite/tools/converter/ops/exit.h new file mode 100644 index 0000000000..69db67a8c0 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/exit.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { + +class Exit : public PrimitiveC { + public: + Exit() { op_type_ = ConverterPrimitiveType_Exit; } + ~Exit() = default; + MS_DECLARE_PARENT(Exit, PrimitiveC); + explicit Exit(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ diff --git a/mindspore/lite/tools/converter/ops/loop_cond.cc b/mindspore/lite/tools/converter/ops/loop_cond.cc new file mode 100644 index 0000000000..3803af5fcc --- /dev/null +++ b/mindspore/lite/tools/converter/ops/loop_cond.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/ops/loop_cond.h" +#include "src/tensorlist.h" + +namespace mindspore { +namespace lite { + +int LoopCond::InferShape(std::vector inputs_, std::vector outputs_) { + if (!infer_flag()) { + return RET_INFER_INVALID; + } + for (size_t i = 0; i < inputs_.size(); i++) { + auto *input = inputs_[i]; + auto *output = outputs_[i]; + if (input == nullptr) { + MS_LOG(ERROR) << "input tensor is nullptr"; + return RET_ERROR; + } + if (output == nullptr) { + MS_LOG(ERROR) << "output tensor is nullptr"; + return RET_ERROR; + } + output->set_data_type(input->data_type()); + output->set_shape(input->shape()); + output->set_format(input->format()); + auto data_type = input->data_type(); + if (data_type != kObjectTypeTensorType) { + continue; + } else { + auto input_tensorlist = reinterpret_cast(input); + auto output_tensorlist = reinterpret_cast(output); + output_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + } + } + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/ops/loop_cond.h b/mindspore/lite/tools/converter/ops/loop_cond.h new file mode 100644 index 0000000000..12eb6824b0 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/loop_cond.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { + +class LoopCond : public PrimitiveC { + public: + LoopCond() { op_type_ = ConverterPrimitiveType_LoopCond; } + ~LoopCond() = default; + MS_DECLARE_PARENT(LoopCond, PrimitiveC); + explicit LoopCond(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ diff --git a/mindspore/lite/tools/converter/ops/next_iteration.cc b/mindspore/lite/tools/converter/ops/next_iteration.cc new file mode 100644 index 0000000000..6b554e39b5 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/next_iteration.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/ops/next_iteration.h" +#include "src/tensorlist.h" + +namespace mindspore { +namespace lite { + +int NextIteration::InferShape(std::vector inputs_, std::vector outputs_) { + if (!infer_flag()) { + return RET_INFER_INVALID; + } + for (size_t i = 0; i < inputs_.size(); i++) { + auto *input = inputs_[i]; + auto *output = outputs_[i]; + if (input == nullptr) { + MS_LOG(ERROR) << "input tensor is nullptr"; + return RET_ERROR; + } + if (output == nullptr) { + MS_LOG(ERROR) << "output tensor is nullptr"; + return RET_ERROR; + } + output->set_data_type(input->data_type()); + output->set_shape(input->shape()); + output->set_format(input->format()); + auto data_type = input->data_type(); + if (data_type != kObjectTypeTensorType) { + continue; + } else { + auto input_tensorlist = reinterpret_cast(input); + auto output_tensorlist = reinterpret_cast(output); + output_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + } + } + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/ops/next_iteration.h b/mindspore/lite/tools/converter/ops/next_iteration.h new file mode 100644 index 0000000000..56d8b66767 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/next_iteration.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ +#define LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { + +class NextIteration : public PrimitiveC { + public: + NextIteration() { op_type_ = ConverterPrimitiveType_NextIteration; } + ~NextIteration() = default; + MS_DECLARE_PARENT(NextIteration, PrimitiveC); + explicit NextIteration(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ diff --git a/mindspore/lite/tools/converter/ops/ops_def.h b/mindspore/lite/tools/converter/ops/ops_def.h new file mode 100644 index 0000000000..16a565ae48 --- /dev/null +++ b/mindspore/lite/tools/converter/ops/ops_def.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ +#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { + +enum ConverterPrimitiveType { + ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1, + ConverterPrimitiveType_LoopCond, + ConverterPrimitiveType_NextIteration, + ConverterPrimitiveType_Exit, +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc new file mode 100644 index 0000000000..5a3785b4dc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/parser/tf/tf_enter_parser.h" +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/ops/enter.h" + +namespace mindspore { +namespace lite { +STATUS TFEnterParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF EnterParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + *primitiveC = new (std::nothrow) Enter(); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfEnterParser("Enter", new TFEnterParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.h new file mode 100644 index 0000000000..e1f2119e11 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFEnterParser : public TFNodeParser { + public: + TFEnterParser() = default; + ~TFEnterParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc new file mode 100644 index 0000000000..8c42b5b27b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_exit_parser.h" +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/ops/exit.h" + +namespace mindspore { +namespace lite { +STATUS TFExitParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ExitParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + *primitiveC = new (std::nothrow) Exit(); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfExitParser("Exit", new TFExitParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.h new file mode 100644 index 0000000000..f53bb92b40 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFExitParser : public TFNodeParser { + public: + TFExitParser() = default; + ~TFExitParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc new file mode 100644 index 0000000000..a1199aaf03 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_loop_cond_parser.h" +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/ops/loop_cond.h" + +namespace mindspore { +namespace lite { +STATUS TFLoopCondParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF LoopCondParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + *primitiveC = new (std::nothrow) LoopCond(); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfLoopCondParser("LoopCond", new TFLoopCondParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.h new file mode 100644 index 0000000000..4ec996c47d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFLoopCondParser : public TFNodeParser { + public: + TFLoopCondParser() = default; + ~TFLoopCondParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc new file mode 100644 index 0000000000..ced56a501f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_merge_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFMergeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF MergeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Merge; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfMergeParser("Merge", new TFMergeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.h new file mode 100644 index 0000000000..400d1ce10c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFMergeParser : public TFNodeParser { + public: + TFMergeParser() = default; + ~TFMergeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 747bf3ae27..632cfa7ae3 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -277,7 +277,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co } param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t)); } else { - MS_LOG(ERROR) << "Unsupport dataType: " << type; + MS_LOG(ERROR) << "Unsupported dataType: " << type; return RET_ERROR; } @@ -417,6 +417,16 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin MS_LOG(ERROR) << "Convert ops failed."; return nullptr; } + + if (!nodes_with_null_input_.empty()) { + status = ConnectNullInput(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Connect null inputs failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + } + status = ConvertRootGraphOutputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; @@ -474,16 +484,16 @@ STATUS TFModelParser::ConvertSubgraph() { std::vector sub_graph_inputs; for (int j = 0; j < input_arg_size; j++) { auto &input_arg = tf_sub_signature.input_arg(j); - auto paramter = sub_func_graph->add_parameter(); - paramter->set_name(input_arg.name()); - anf_sub_node_map[input_arg.name()] = paramter; + auto parameter = sub_func_graph->add_parameter(); + parameter->set_name(input_arg.name()); + anf_sub_node_map[input_arg.name()] = parameter; auto root_inputs = cnode->inputs(); if (op_type == schema::PrimitiveType_While) { - paramter->set_abstract(root_inputs[j + 1]->abstract()); + parameter->set_abstract(root_inputs[j + 1]->abstract()); } else { - paramter->set_abstract(root_inputs[j + 2]->abstract()); + parameter->set_abstract(root_inputs[j + 2]->abstract()); } - sub_graph_inputs.emplace_back(paramter); + sub_graph_inputs.emplace_back(parameter); } std::map tf_sub_node_map; for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { @@ -643,7 +653,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector &input_names, const std::map &tf_node_map, const std::unordered_map &anf_node_map, - std::vector *inputs) { + std::vector *inputs, + std::vector *input_name_not_found) { MS_ASSERT(node_def != nullptr); // parse inputs for (size_t j = 0; j < input_names.size(); j++) { @@ -656,8 +667,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, } auto input = GetAnfNode(flatten_input_name, anf_node_map); if (input == nullptr) { - MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; - return RET_ERROR; + MS_LOG(WARNING) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; + (*input_name_not_found).push_back(flatten_input_name); } inputs->emplace_back(input); } @@ -718,6 +729,27 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C return RET_OK; } +STATUS TFModelParser::RecordNullInput(const CNodePtr &node, const std::vector &input_name_not_found) { + nodes_with_null_input_.emplace_back(node, input_name_not_found); + return RET_OK; +} + +STATUS TFModelParser::ConnectNullInput() { + for (auto &it : nodes_with_null_input_) { + auto &cnode = it.first; + auto &input_name_not_found = it.second; + auto &inputs = cnode->inputs(); + int i = 0; + for (size_t j = 0; j < inputs.size(); ++j) { + if (inputs[j] == nullptr) { + cnode->set_input(j, GetAnfNode(input_name_not_found[i], anf_root_node_map_)); + ++i; + } + } + } + return RET_OK; +} + STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, const std::map &tf_node_map, const FuncGraphPtr &func_graph_ptr, @@ -752,7 +784,8 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, return RET_ERROR; } std::vector inputs = {value_node}; - status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs); + std::vector input_name_not_found{}; + status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs, &input_name_not_found); if (status != RET_OK) { return status; } @@ -787,6 +820,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, } } + if (!input_name_not_found.empty()) { + RecordNullInput(anf_node, input_name_not_found); + } + status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); if (status != RET_OK) { MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 0f20d7cbe9..cd16e76378 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "proto/graph.pb.h" #include "proto/node_def.pb.h" #include "schema/inner/model_generated.h" @@ -55,7 +56,7 @@ class TFModelParser : public ModelParser { STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector &input_names, const std::map &tf_node_map, const std::unordered_map &anf_node_map, - std::vector *inputs); + std::vector *inputs, std::vector *input_name_not_found); STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, std::unordered_map *anf_node_map, const FuncGraphPtr &anf_graph, int output_size); @@ -71,6 +72,10 @@ class TFModelParser : public ModelParser { STATUS MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph); + STATUS RecordNullInput(const CNodePtr &node, const std::vector &input_name_not_found); + + STATUS ConnectNullInput(); + FuncGraphPtr anf_root_graph_; std::unique_ptr tf_root_graph_; // tf root graph def std::map tf_root_graph_nodes_; // tf root graph node map @@ -79,6 +84,7 @@ class TFModelParser : public ModelParser { std::vector graph_output_names_; std::map function_while_map_; // tf function name->while_node_name std::map function_if_map_; // tf function name->if_node + std::vector>> nodes_with_null_input_{}; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc new file mode 100644 index 0000000000..9469266525 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_next_iteration_parser.h" +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/ops/next_iteration.h" + +namespace mindspore { +namespace lite { +STATUS TFNextIterationParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF NextIterationParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + *primitiveC = new (std::nothrow) NextIteration(); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfNextIterationParser("NextIteration", new TFNextIterationParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.h new file mode 100644 index 0000000000..430a42fe13 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFNextIterationParser : public TFNodeParser { + public: + TFNextIterationParser() = default; + ~TFNextIterationParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc new file mode 100644 index 0000000000..0c87cf5fdf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_switch_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SwitchParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Switch; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfSwitchParser("Switch", new TFSwitchParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.h new file mode 100644 index 0000000000..7874a0f7be --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSwitchParser : public TFNodeParser { + public: + TFSwitchParser() = default; + ~TFSwitchParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc new file mode 100644 index 0000000000..b878a55304 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "tools/optimizer/graph/functionalize_control_op_pass.h" +#include "tools/optimizer/graph/functionalize_while.h" +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { + +FuncGraphPtr FunctionalizeControlOpPass::NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type) { + auto fg = std::make_shared(); + if (fg == nullptr) { + MS_LOG(ERROR) << "new func)graph failed."; + return nullptr; + } + fg->set_attr("graph_name", MakeValue(subgraph_name)); + fg->set_attr("fmk", MakeValue(static_cast(fmk_type))); + return fg; +} + +std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) { + std::string cluster_name{}; + // tf node name use '/' split node name + auto cnode = utils::cast(node); + size_t pos = cnode->fullname_with_scope().rfind('/'); + if (pos != std::string::npos) { + cluster_name = cnode->fullname_with_scope().substr(0, pos); + } else { + cluster_name = cnode->fullname_with_scope(); + } + return cluster_name; +} + +void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { + for (auto &node : func_graph->nodes()) { + auto cluster_name = NodeClusterName(node); + auto cluster_pos = WhichCluster(cluster_name); + if (cluster_pos == node_clusters_.size()) { + std::vector node_list{node}; + node_clusters_.emplace_back(std::make_pair(cluster_name, node_list)); + } else { + node_clusters_[cluster_pos].second.push_back(node); + } + } +} + +size_t FunctionalizeControlOpPass::WhichCluster(const std::string &cluster_name) { + size_t pos = node_clusters_.size(); + for (size_t i = 0; i < pos; ++i) { + if (node_clusters_[i].first == cluster_name) { + return i; + } + } + return pos; +} + +STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_graph) { + int ret = RET_OK; + for (auto &node_cluster : node_clusters_) { + for (auto &node : node_cluster.second) { + if (IsLoopCond(node)) { + loop_cond_nodes_.push_back(node->cast()); + FunctionalizeWhile fw(node_cluster.second, node->cast(), func_graph); + ret = fw.Process(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; + return ret; + } + } + } + } + return ret; +} + +bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) { + // use name to find the frame + InitNodeClusters(func_graph); + if (BuildWhileSubgraph(func_graph) != RET_OK) { + MS_LOG(ERROR) << "build while subgraph failed."; + return false; + } + return true; +} +CNodePtr FunctionalizeControlOpPass::BelongToWhichNode(const CNodePtr &node, const FilterFunc &func) { + if (node == nullptr) { + return nullptr; + } + if (func(node)) { + return node; + } + CNodePtr aim_node = nullptr; + std::deque todo(256); + todo.clear(); + for (auto &input_node : node->inputs()) { + if (func(input_node)) { + aim_node = utils::cast(input_node); + todo.clear(); + break; + } + todo.push_back(input_node); + } + + while (!todo.empty()) { + AnfNodePtr todo_node = todo.front(); + todo.pop_front(); + if (func(todo_node)) { + aim_node = utils::cast(todo_node); + todo.clear(); + break; + } + if (utils::isa(todo_node)) { + auto cnode = utils::cast(todo_node); + for (size_t i = 0; i < cnode->inputs().size(); i++) { + todo.push_back(cnode->input(i)); + } + } + } + if (aim_node == nullptr) { + MS_LOG(WARNING) << "not found belonging enter node."; + return nullptr; + } + + return aim_node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h new file mode 100644 index 0000000000..68f305e8b5 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ +#include +#include +#include +#include +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/gllo_utils.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class FunctionalizeControlOpPass : public Pass { + public: + FunctionalizeControlOpPass() : Pass("functionalize_control_op_pass") {} + ~FunctionalizeControlOpPass() override = default; + bool Run(const FuncGraphPtr &graph) override; + static FuncGraphPtr NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type); + static bool IsMerge(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Merge; } + static bool IsLoopCond(const AnfNodePtr &node) { + return static_cast(opt::GetCNodeType(node)) == static_cast(lite::ConverterPrimitiveType_LoopCond); + } + static bool IsEnter(const AnfNodePtr &node) { + return static_cast(opt::GetCNodeType(node)) == static_cast(lite::ConverterPrimitiveType_Enter); + } + static bool IsExit(const AnfNodePtr &node) { + return static_cast(opt::GetCNodeType(node)) == static_cast(lite::ConverterPrimitiveType_Exit); + } + static bool IsSwitch(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Switch; } + static bool IsNextIteration(const AnfNodePtr &node) { + return static_cast(opt::GetCNodeType(node)) == static_cast(lite::ConverterPrimitiveType_NextIteration); + } + static bool IsControlFlowOp(const AnfNodePtr &node) { + return IsLoopCond(node) || IsEnter(node) || IsMerge(node) || IsSwitch(node) || IsExit(node) || + IsNextIteration(node); + } + static CNodePtr BelongToWhichNode(const CNodePtr &node, const FilterFunc &func); + static int GetSubgraphIndex() { + static int subgraph_index = 1; + return subgraph_index++; + } + // The names of nodes with the same prefix are a cluster. + static std::string NodeClusterName(const AnfNodePtr &node); + void InitNodeClusters(const FuncGraphPtr &func_graph); + // return the position in node_clusters_ + size_t WhichCluster(const std::string &cluster_name); + + protected: + STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph); + std::vector>> node_clusters_{}; + std::vector loop_cond_nodes_{}; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc new file mode 100644 index 0000000000..eed05a7dd6 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc @@ -0,0 +1,521 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/optimizer/graph/functionalize_while.h" +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" +#include "src/ops/while.h" + +namespace { +mindspore::ValueNodePtr GetWhileAnfPrim() { + auto while_primitiveT = new (std::nothrow) mindspore::schema::PrimitiveT; + if (while_primitiveT == nullptr) { + MS_LOG(ERROR) << "new while_primitiveT failed"; + return nullptr; + } + while_primitiveT->value.type = mindspore::schema::PrimitiveType_While; + auto whileT = new (std::nothrow) mindspore::schema::WhileT; + whileT->condSubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex(); + whileT->bodySubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex(); + while_primitiveT->value.value = whileT; + if (while_primitiveT->value.value == nullptr) { + MS_LOG(ERROR) << "new WhileT failed"; + delete (while_primitiveT); + return nullptr; + } + + auto while_prim = std::make_shared(while_primitiveT); + mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_prim); + return partial_anf_prim; +} +} // namespace + +namespace mindspore::opt { + +using mindspore::lite::RET_NULL_PTR; + +CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) { + return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch); +} +CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) { + return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge); +} +CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) { + return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter); +} + +int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) { + auto index = std::find(input_enter_nodes_.begin(), input_enter_nodes_.end(), node); + if (index == input_enter_nodes_.end()) { + MS_LOG(WARNING) << node->fullname_with_scope() << " is not in input_enter_nodes_"; + return -1; + } + return index - input_enter_nodes_.begin(); +} + +STATUS FunctionalizeWhile::NewWhileNode() { + ValueNodePtr while_anf_primitive = GetWhileAnfPrim(); + if (while_anf_primitive == nullptr) { + MS_LOG(ERROR) << "Get while anf primitive failed."; + return RET_NULL_PTR; + } + + static int count = 0; + std::vector while_op_inputs = {while_anf_primitive}; + while_node_ = fg_->NewCNode(while_op_inputs); + while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++)); + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyWhileNodeInput() { + for (auto &node : node_cluster_) { + if (FunctionalizeControlOpPass::IsEnter(node)) { + auto enter_cnode = node->cast(); + input_enter_nodes_.push_back(enter_cnode); + while_node_->add_input(enter_cnode->input(1)); + } + } + if (input_enter_nodes_.empty()) { + MS_LOG(ERROR) << "not found input of while node."; + return RET_ERROR; + } + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { + output_exit_nodes_.resize(input_enter_nodes_.size()); + for (auto &node : node_cluster_) { + // exit ->switch->merge->enter + if (FunctionalizeControlOpPass::IsExit(node)) { + auto exit_node = node->cast(); + auto switch_node = BlongToWhichSwitch(exit_node); + auto merge_node = BlongToWhichMerge(switch_node); + auto enter_node = BlongToWhichEnter(merge_node); + int pos = PosInInputEnterNodes(enter_node); + if (pos == -1) { + MS_LOG(ERROR) << "not find in input enter nodes."; + return RET_ERROR; + } + output_exit_nodes_.at(pos) = exit_node; + } + } + + if (output_exit_nodes_.size() == 1) { + while_node_->set_abstract(output_exit_nodes_[0]->abstract()); + } else { + AbstractBasePtrList abstract_list; + abstract_list.resize(output_exit_nodes_.size()); + std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(), + [](const CNodePtr &cnode) { return cnode->abstract(); }); + while_node_->set_abstract(std::make_shared(abstract_list)); + } + return RET_OK; +} + +STATUS FunctionalizeWhile::UpdateExitNodeUser() { + if (output_exit_nodes_.size() == 1) { + auto manager = fg_->manager(); + auto node_users = manager->node_users()[output_exit_nodes_[0]]; + for (auto &node_user : node_users) { + if (fg_->nodes().contains(node_user.first)) { + manager->SetEdge(node_user.first, node_user.second, while_node_); + } + } + } else { + for (auto &node : output_exit_nodes_) { + auto manager = fg_->manager(); + auto node_users = manager->node_users()[node]; + for (auto &node_user : node_users) { + // new getitem + AbstractBasePtrList abstractList; + std::vector shape_vector; + abstractList.emplace_back(std::make_shared(kFloat32, shape_vector)); + auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim(); + if (tuple_get_item_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; + return RET_NULL_PTR; + } + auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); + const auto &exit_node = node; + auto switch_node = BlongToWhichSwitch(exit_node); + auto merge_node = BlongToWhichMerge(switch_node); + auto enter_node = BlongToWhichEnter(merge_node); + int output_idx = PosInInputEnterNodes(enter_node); + auto getItemValue = NewValueNode(MakeValue(output_idx)); + std::vector inputs{tuple_get_item_prim, while_node_, getItemValue}; + CNodePtr get_item_node = fg_->NewCNode(inputs); + std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); + auto abstract = std::make_shared(kFloat32, shape_vector); + if (abstract == nullptr) { + MS_LOG(ERROR) << "create AbstractTensor failed"; + return RET_NULL_PTR; + } + get_item_node->set_abstract(abstract); + get_item_node->set_fullname_with_scope(output_item_name); + // set + if (fg_->nodes().contains(node_user.first)) { + manager->SetEdge(node_user.first, node_user.second, get_item_node); + } + } + } + } + return RET_OK; +} + +STATUS FunctionalizeWhile::BuildWhileNode() { + int ret = NewWhileNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "new while node failed, ret:" << ret; + return ret; + } + ret = IdentifyWhileNodeInput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify while node input failed, ret:" << ret; + return ret; + } + ret = IdentifyWhileNodeOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify while node output failed, ret:" << ret; + return ret; + } + // update exit node user from exit to while + ret = UpdateExitNodeUser(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "update while node users, ret:" << ret; + return ret; + } + + return ret; +} + +// nodes between loop_cond op and merge op be added into cond_func_graph +STATUS FunctionalizeWhile::CondSubgraphAddNodes() { + std::deque todo(512); + todo.clear(); + for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) { + todo.push_back(loop_cond_node_->input(i)); + } + while (!todo.empty()) { + AnfNodePtr node = todo.back(); + todo.pop_back(); + if (FunctionalizeControlOpPass::IsMerge(node)) { + continue; + } + if (utils::isa(node)) { + cond_sub_func_graph_->add_parameter(node->cast()); + } else { + cond_sub_func_graph_->AddNode(node); + } + node->set_func_graph(cond_sub_func_graph_); + if (utils::isa(node)) { + auto cnode = utils::cast(node); + for (size_t i = 1; i < cnode->inputs().size(); i++) { + todo.push_back(cnode->input(i)); + } + } + } + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() { + std::vector nodes_need_drop{}; + for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) { + for (auto &input_node : cnode->inputs()) { + if (FunctionalizeControlOpPass::IsMerge(input_node)) { + auto merge_node = input_node->cast(); + auto enter_node = BlongToWhichEnter(merge_node); + int pos = PosInInputEnterNodes(enter_node); + nodes_need_drop.push_back(cnode); + + // set parameter + auto parameter = cond_sub_func_graph_->add_parameter(); + parameter->set_abstract(cnode->abstract()); + // hardcode for subgraph input name + parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); + + // replace merge + auto manager = fg_->manager(); + auto node_users = manager->node_users()[cnode]; + for (auto &node_user : node_users) { + if (cond_sub_func_graph_->nodes().contains(node_user.first)) { + manager->SetEdge(node_user.first, node_user.second, parameter); + } + } + } + } + } + + // drop node from cond_func_graph + for (const auto &node : nodes_need_drop) { + cond_sub_func_graph_->DropNode(node); + } + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() { + auto return_prim_ptr = lite::GetReturnPrim(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; + } + auto value_node = NewValueNode(return_prim_ptr); + if (value_node == nullptr) { + MS_LOG(ERROR) << "new value_node failed."; + return RET_NULL_PTR; + } + // cond subgraph output is LoopCond's input + std::vector op_inputs{value_node, loop_cond_node_->input(1)}; + auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs); + return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return"); + cond_sub_func_graph_->set_return(return_cnode); + + // hardcode subgraph outputs name + cond_sub_func_graph_->output()->cast()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode"); + return RET_OK; +} + +STATUS FunctionalizeWhile::BuildCondGraph() { + cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond"; + cond_sub_func_graph_ = + FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF); + if (cond_sub_func_graph_ == nullptr) { + MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr"; + return RET_NULL_PTR; + } + cond_sub_func_graph_->set_manager(fg_->manager()); + + int ret = CondSubgraphAddNodes(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << ret; + return ret; + } + ret = IdentifyCondSubgraphOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << ret; + return ret; + } + ret = IdentifyCondSubgraphInput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << ret; + return ret; + } + + return ret; +} + +// nodes between next_iteration op and switch op will be added into body_func_graph +STATUS FunctionalizeWhile::BodySubgraphAddNodes() { + std::deque todo(512); + todo.clear(); + for (auto &node : node_cluster_) { + if (FunctionalizeControlOpPass::IsNextIteration(node)) { + auto next_iteration_cnode = node->cast(); + for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) { + todo.push_back(next_iteration_cnode->input(i)); + } + body_subgraph_output_map_[node] = next_iteration_cnode->input(1); + } + } + + while (!todo.empty()) { + AnfNodePtr node = todo.back(); + todo.pop_back(); + if (FunctionalizeControlOpPass::IsSwitch(node)) { + continue; + } + if (utils::isa(node)) { + body_sub_func_graph_->add_parameter(node->cast()); + } else { + body_sub_func_graph_->AddNode(node); + } + node->set_func_graph(body_sub_func_graph_); + if (utils::isa(node)) { + auto cnode = utils::cast(node); + for (size_t i = 1; i < cnode->inputs().size(); i++) { + todo.push_back(cnode->input(i)); + } + } + } + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() { + std::vector nodes_need_drop{}; + for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) { + for (auto &input_node : cnode->inputs()) { + if (FunctionalizeControlOpPass::IsSwitch(input_node)) { + auto switch_node = input_node->cast(); + auto merge_node = BlongToWhichMerge(switch_node); + auto enter_node = BlongToWhichEnter(merge_node); + int pos = PosInInputEnterNodes(enter_node); + nodes_need_drop.push_back(cnode); + + // set parameter + auto parameter = body_sub_func_graph_->add_parameter(); + parameter->set_abstract(cnode->abstract()); + // hardcode for subgraph input name + parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); + + // replace switch + auto manager = fg_->manager(); + auto node_users = manager->node_users()[cnode]; + for (auto &node_user : node_users) { + if (body_sub_func_graph_->nodes().contains(node_user.first)) { + manager->SetEdge(node_user.first, node_user.second, parameter); + } + } + } + } + } + + // drop node from cond_func_graph + for (const auto &node : nodes_need_drop) { + body_sub_func_graph_->DropNode(node); + } + return RET_OK; +} + +STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() { + std::vector tmp_output{}; + tmp_output.resize(input_enter_nodes_.size()); + // next_iteration -> switch -> merge -> enter + for (auto &node_pair : body_subgraph_output_map_) { + auto next_iteration_cnode = utils::cast(node_pair.first); + auto switch_node = BlongToWhichSwitch(next_iteration_cnode); + auto merge_node = BlongToWhichMerge(switch_node); + auto enter_node = BlongToWhichEnter(merge_node); + int pos = PosInInputEnterNodes(enter_node); + + tmp_output[pos] = node_pair.second; + // hard code. set cnode output name + node_pair.second->cast()->set_fullname_with_scope(body_subgraph_name_ + "_output_" + std::to_string(pos) + + "_cnode"); + } + + auto return_prim_ptr = lite::GetReturnPrim(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; + } + auto value_node = NewValueNode(return_prim_ptr); + // cond subgraph output is LoopCond's input + std::vector op_inputs{value_node}; + auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs); + return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return"); + + if (tmp_output.size() == 1) { + return_cnode->add_input(tmp_output[0]); + } else { + std::vector make_tuple_inputs = tmp_output; + auto make_tuple_prim_ptr = lite::GetMakeTuplePrim(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; + return RET_NULL_PTR; + } + auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); + make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); + auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple"); + + return_cnode->add_input(make_tuple_cnode); + } + + body_sub_func_graph_->set_return(return_cnode); + return RET_OK; +} + +STATUS FunctionalizeWhile::BuildBodyGraph() { + body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body"; + body_sub_func_graph_ = + FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF); + if (body_sub_func_graph_ == nullptr) { + MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr"; + return RET_NULL_PTR; + } + body_sub_func_graph_->set_manager(fg_->manager()); + + int ret = BodySubgraphAddNodes(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << ret; + return ret; + } + ret = IdentifyBodySubgraphOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << ret; + return ret; + } + ret = IdentifyBodySubgraphInput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << ret; + return ret; + } + return ret; +} + +STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() { + // set while input cond and body vnode + auto cond_value_node = NewValueNode(cond_sub_func_graph_); + auto body_value_node = NewValueNode(body_sub_func_graph_); + auto inputs = while_node_->inputs(); + inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node}); + while_node_->set_inputs(inputs); + return RET_OK; +} + +STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() { + // fg_ drop cluster node + for (auto &node : node_cluster_) { + fg_->DropNode(node); + } + return RET_OK; +} + +STATUS FunctionalizeWhile::Process() { + int ret = BuildWhileNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "build while node failed, ret:" << ret; + return ret; + } + + ret = BuildCondGraph(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "build while node failed, ret:" << ret; + return ret; + } + + ret = BuildBodyGraph(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "build while node failed, ret:" << ret; + return ret; + } + + ret = InsertFuncGraphToWhileInput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret; + return ret; + } + + ret = DropUselessNodesInMainGraph(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret; + return ret; + } + return ret; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.h b/mindspore/lite/tools/optimizer/graph/functionalize_while.h new file mode 100644 index 0000000000..47bd5769db --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.h @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ +#define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ +#include +#include +#include +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "tools/optimizer/graph/functionalize_control_op_pass.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { + +class FunctionalizeWhile { + public: + FunctionalizeWhile(std::vector node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg) + : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {} + + // while + STATUS BuildWhileNode(); + STATUS IdentifyWhileNodeInput(); + STATUS IdentifyWhileNodeOutput(); + STATUS UpdateExitNodeUser(); + STATUS NewWhileNode(); + STATUS InsertFuncGraphToWhileInput(); + + // cond subgraph + STATUS BuildCondGraph(); + STATUS CondSubgraphAddNodes(); + STATUS IdentifyCondSubgraphInput(); + STATUS IdentifyCondSubgraphOutput(); + + // body subgraph + STATUS BuildBodyGraph(); + STATUS BodySubgraphAddNodes(); + STATUS IdentifyBodySubgraphInput(); + STATUS IdentifyBodySubgraphOutput(); + + CNodePtr BlongToWhichSwitch(const CNodePtr &node); + CNodePtr BlongToWhichMerge(const CNodePtr &node); + CNodePtr BlongToWhichEnter(const CNodePtr &node); + int PosInInputEnterNodes(const CNodePtr &node); + STATUS DropUselessNodesInMainGraph(); + + STATUS Process(); + + private: + std::vector node_cluster_{}; + const CNodePtr loop_cond_node_; + FuncGraphPtr fg_; + + FuncGraphPtr cond_sub_func_graph_ = nullptr; + FuncGraphPtr body_sub_func_graph_ = nullptr; + CNodePtr while_node_ = nullptr; + + std::string cond_subgraph_name_{}; + std::string body_subgraph_name_{}; + + // while + std::vector input_enter_nodes_{}; + std::vector output_exit_nodes_{}; + + // pair (next iteration node, next iteration node input) + std::map body_subgraph_output_map_{}; + // pair (switch node, switch output in body graph) + std::map body_subgraph_input_map_{}; + // pair (switch node, switch output in body graph) + std::map cond_subgraph_input_map_{}; +}; + +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index fb8b84cddf..8504626d10 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -325,26 +325,6 @@ STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const Fu return RET_OK; } -STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { - auto body_partial_cnode = switch_cnode->input(2)->cast(); - MS_ASSERT(body_partial_cnode != nullptr); - auto body_vnode = body_partial_cnode->input(0)->cast(); - MS_ASSERT(body_vnode != nullptr); - auto body_fg = GetValueNode(body_vnode); - MS_ASSERT(body_fg != nullptr); - AbstractBasePtrList abstract_list; - auto body_fg_output_cnode = utils::cast(body_fg->output()); - for (auto &cnode : body_fg_output_cnode->inputs()) { - if (!utils::isa(cnode) && !utils::isa(cnode)) { - continue; - } - abstract_list.push_back(cnode->abstract()); - } - - switch_cnode->set_abstract(std::make_shared(abstract_list)); - return RET_OK; -} - bool InferShapePass::Run(const FuncGraphPtr &func_graph) { if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { MS_LOG(INFO) << "The framework type of model should be tf/tflite."; @@ -384,14 +364,6 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { } auto type = GetCNodeType(cnode); - if (type == schema::PrimitiveType_Switch) { - int ret = SwitchCNodeInferShape(cnode); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PartialCNodeInferShape failed."; - return false; - } - } - if ((type == schema::PrimitiveType_TupleGetItem) || #ifdef SUPPORT_TRAIN (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index d0955079c9..b316430f7d 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -41,7 +41,6 @@ class InferShapePass : public Pass { STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors); STATUS SetParameterAbstract(const ParameterPtr ¶meter); STATUS SetCNodeAbstract(const std::vector &output_tensors, const std::shared_ptr &cnode); - STATUS SwitchCNodeInferShape(const CNodePtr &cnode); int StrIsContain(const std::vector &total, const std::string &aim); int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph);