functionalize while

pull/11324/head
mengyuanli 4 years ago
parent 37f00fdaab
commit 0f426c434f

@ -1103,10 +1103,13 @@ schema::QuantType PrimitiveC::quant_type() const { return quant_type_; }
#endif
int PrimitiveC::Type() const {
if (this->primitive_ == nullptr) {
if (this->primitive_ == nullptr && this->op_type_ == OP_TYPE_NOT_SET) {
return schema::PrimitiveType_NONE;
}
#ifdef PRIMITIVE_WRITEABLE
if (op_type_ != OP_TYPE_NOT_SET) {
return op_type_;
}
return this->primitive_->value.type;
#else
return this->primitive_->value_type();

@ -24,6 +24,9 @@
#ifdef PRIMITIVE_WRITEABLE
#include "ir/primitive.h"
#include "schema/inner/model_generated.h"
#include "schema/inner/ops_generated.h"
#include "schema/ops_generated.h"
#include "tools/converter/ops/ops_def.h"
#else
#include "schema/model_generated.h"
#endif
@ -34,6 +37,7 @@
namespace mindspore {
namespace lite {
constexpr const int OP_TYPE_NOT_SET = -1;
constexpr uint32_t kSingleNum = 1;
constexpr uint32_t kDoubleNum = 2;
constexpr uint32_t kMultiNum = 3;
@ -149,6 +153,7 @@ class PrimitiveC : public mindspore::Primitive {
std::vector<std::vector<schema::QuantParamT>> output_quant_param_;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
bool infer_flag_ = true;
int op_type_ = OP_TYPE_NOT_SET;
};
std::shared_ptr<PrimitiveC> GetReturnPrim();
@ -227,6 +232,7 @@ class PrimitiveC {
char *primitive_buf_ = nullptr;
bool infer_flag_ = true;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
int op_type_ = OP_TYPE_NOT_SET;
};
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);

@ -2,6 +2,7 @@ set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(TEST_DIR ${TOP_DIR}/mindspore/lite/test)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
set(CONVERTER_DIR ${TOP_DIR}/mindspore/lite/tools/converter)
include_directories(${TOP_DIR})
include_directories(${TEST_DIR})
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/gtest.cmake)
@ -16,7 +17,7 @@ set(CCSRC_SRC
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
)
else(ENABLE_CONVERTER)
else()
set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
add_compile_definitions(USE_ANDROID_LOG)
endif()
@ -38,10 +39,10 @@ file(GLOB KERNEL_OP_TRAIN_SRC
${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
)
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
endif()
if (PLATFORM_ARM64)
if(PLATFORM_ARM64)
# assembly
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/assembly/arm64/*.s
${LITE_DIR}/nnacl/assembly/arm64/*.S)
@ -53,7 +54,7 @@ if (PLATFORM_ARM64)
)
endif()
if (PLATFORM_ARM32)
if(PLATFORM_ARM32)
# assembly
file(GLOB TEST_ASSEMBLY_SRC
${LITE_DIR}/nnacl/assembly/arm32/*.S
@ -65,7 +66,7 @@ if (PLATFORM_ARM32)
)
endif()
if ("${X86_64_SIMD}" STREQUAL "sse")
if("${X86_64_SIMD}" STREQUAL "sse")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
@ -74,12 +75,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
)
endif()
if ("${X86_64_SIMD}" STREQUAL "avx")
if("${X86_64_SIMD}" STREQUAL "avx")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
${LITE_DIR}/nnacl/x86_64_avx/*.c
${LITE_DIR}/nnacl/assembly/avx/*.S)
${LITE_DIR}/nnacl/x86_64_avx/*.c
${LITE_DIR}/nnacl/assembly/avx/*.S)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
@ -88,7 +89,7 @@ if ("${X86_64_SIMD}" STREQUAL "avx")
endif()
### gpu kernel
if (SUPPORT_GPU)
if(SUPPORT_GPU)
file(GLOB GPU_KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc
)
@ -102,14 +103,18 @@ if (SUPPORT_GPU)
)
endif()
if (PLATFORM_ARM32 OR PLATFORM_ARM64)
if (ENABLE_CONVERTER)
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
if(ENABLE_CONVERTER)
set(BUILD_MINDDATA "off")
endif()
endif()
### runtime framework
add_definitions(-DENABLE_V0)
file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc)
if(ENABLE_CONVERTER)
file(GLOB_RECURSE CONVERTER_OPS_SRC ${CONVERTER_DIR}/ops/*.cc)
set(OPS_SRC ${OPS_SRC} ${CONVERTER_OPS_SRC})
endif()
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${CCSRC_SRC}
@ -144,7 +149,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/errorcode.cc
)
### gpu runtime
if (SUPPORT_GPU)
if(SUPPORT_GPU)
include_directories(${TOP_DIR}/third_party/OpenCL-Headers)
include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include)
set(OPENCL_RUNTIME_SRC
@ -210,13 +215,14 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
)
endif()
### train
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
# ${LITE_DIR}/src/train/ops/train_ops.cc
${LITE_DIR}/src/train/train_populate_parameter.cc
${LITE_DIR}/src/train/train_session.cc
${LITE_DIR}/src/train/train_model.cc
@ -251,7 +257,7 @@ set(TEST_SRC
${TEST_DIR}/ut/src/scheduler_test.cc
)
if (ENABLE_CONVERTER)
if(ENABLE_CONVERTER)
set(TEST_SRC
${TEST_SRC}
${TEST_DIR}/st/converter_test.cc
@ -265,7 +271,7 @@ if (ENABLE_CONVERTER)
)
endif()
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(TEST_SRC
${TEST_SRC}
${TEST_CASE_KERNEL_TRAIN_SRC}
@ -278,7 +284,7 @@ else()
)
endif()
if (SUPPORT_GPU)
if(SUPPORT_GPU)
file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
)
@ -288,7 +294,7 @@ if (SUPPORT_GPU)
)
endif()
if (ENABLE_FP16)
if(ENABLE_FP16)
file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc
)
@ -296,24 +302,24 @@ if (ENABLE_FP16)
${TEST_SRC}
${TEST_CASE_KERNEL_FP16_SRC}
)
endif ()
endif()
add_executable(lite-test ${TEST_SRC})
add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl mindspore::gtest)
if (PLATFORM_ARM64 AND ENABLE_FP16)
if(PLATFORM_ARM64 AND ENABLE_FP16)
target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid)
endif()
if (PLATFORM_ARM)
if(PLATFORM_ARM)
target_link_libraries(lite-test log)
endif()
if (SUPPORT_NPU)
if(SUPPORT_NPU)
include_directories(${DDK_PATH})
target_link_libraries(lite-test npu_kernel_mid)
endif ()
if (ENABLE_CONVERTER)
endif()
if(ENABLE_CONVERTER)
add_dependencies(lite-test fbs_inner_src)
target_link_libraries(lite-test
anf_importer_mid

@ -12,6 +12,8 @@ include(${TOP_DIR}/cmake/external_libs/glog.cmake)
file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)
file(GLOB CONVERTER_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/converter.cc
@ -65,6 +67,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/while_pass.cc
../optimizer/graph/if_pass.cc
../optimizer/graph/mindir_inputs_adjust_pass.cc
../optimizer/graph/functionalize_control_op_pass.cc
../optimizer/graph/functionalize_while.cc
)
add_subdirectory(../anf_importer anf_importer)
@ -97,12 +101,12 @@ set(LITE_SRC
${SRC_DIR}/errorcode.cc
${SRC_DIR}/dequant.cc
)
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(LITE_SRC
${LITE_SRC}
)
endif ()
endif()
set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm)
file(GLOB KERNEL_SRC
${ARM_DIR}/base/*.cc
@ -114,13 +118,13 @@ file(GLOB KERNEL_SRC
${ARM_DIR}/int8/*.cc
)
if (PLATFORM_ARM64)
if(PLATFORM_ARM64)
# assembly
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.s
${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif ()
endif()
file(GLOB PROTO_FILE ""
${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
@ -133,11 +137,13 @@ add_library(proto_mid OBJECT ${PROTO_SRCS})
set(TFLITE_FBS_FILES
${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/schema.fbs
)
ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src ${CMAKE_BINARY_DIR}/schema "inner")
ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src
${CMAKE_BINARY_DIR}/schema "inner")
set_property(SOURCE ${CONVERTER_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${CONVERTER_OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${KERNEL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_executable(converter_lite
@ -145,6 +151,7 @@ add_executable(converter_lite
${CCSRC_SRC}
${CONVERTER_SRC}
${OPS_SRC}
${CONVERTER_OPS_SRC}
${KERNEL_SRC}
${LITE_SRC}
)

@ -48,6 +48,7 @@
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/while_pass.h"
#include "tools/optimizer/graph/if_pass.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@ -100,6 +101,15 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
}
}
if (config->fmk == lite::converter::FmkType_TF) {
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
if (!functionalize_control_op_pass->Run(old_graph)) {
MS_LOG(ERROR) << "functionalize control op pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
@ -145,7 +155,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified";
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
return nullptr;
}
remove_unused_cast_pass->SetFmkType(config->fmk);
@ -154,7 +164,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
if (remove_unused_transpose_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified";
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified";
return nullptr;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/ops/enter.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
// Infers output tensors for the TF control-flow op Enter, which forwards each
// input into a loop frame unchanged: output i takes the data type, shape and
// format of input i (plus TensorList metadata when the element is a tensor list).
// Returns RET_INFER_INVALID if inference is disabled, RET_ERROR on bad
// inputs/outputs, RET_OK otherwise.
int Enter::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are indexed in lockstep with inputs; guard against an undersized
  // outputs vector to avoid out-of-bounds access.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList carries extra metadata that set_shape/set_data_type do not copy.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ENTER_H_
#define LITE_MINDSPORE_LITE_C_OPS_ENTER_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Converter-only primitive wrapping the TF control-flow op `Enter` (forwards a
// tensor into a while-loop frame). Tagged via op_type_ with a
// ConverterPrimitiveType value since it has no schema::PrimitiveType entry.
class Enter : public PrimitiveC {
 public:
  Enter() { op_type_ = ConverterPrimitiveType_Enter; }
  ~Enter() = default;
  MS_DECLARE_PARENT(Enter, PrimitiveC);
  explicit Enter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Copies data type/shape/format (and TensorList metadata) from each input to
  // the corresponding output.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ENTER_H_

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/ops/exit.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
// Infers output tensors for the TF control-flow op Exit, which forwards each
// input out of a loop frame unchanged: output i takes the data type, shape and
// format of input i (plus TensorList metadata when the element is a tensor list).
// Returns RET_INFER_INVALID if inference is disabled, RET_ERROR on bad
// inputs/outputs, RET_OK otherwise.
int Exit::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are indexed in lockstep with inputs; guard against an undersized
  // outputs vector to avoid out-of-bounds access.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList carries extra metadata that set_shape/set_data_type do not copy.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_EXIT_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXIT_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Converter-only primitive wrapping the TF control-flow op `Exit` (forwards a
// tensor out of a while-loop frame). Tagged via op_type_ with a
// ConverterPrimitiveType value since it has no schema::PrimitiveType entry.
class Exit : public PrimitiveC {
 public:
  Exit() { op_type_ = ConverterPrimitiveType_Exit; }
  ~Exit() = default;
  MS_DECLARE_PARENT(Exit, PrimitiveC);
  explicit Exit(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Copies data type/shape/format (and TensorList metadata) from each input to
  // the corresponding output.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_EXIT_H_

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/ops/loop_cond.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
// Infers output tensors for the TF control-flow op LoopCond, which forwards its
// (boolean) input unchanged: output i takes the data type, shape and format of
// input i (plus TensorList metadata when the element is a tensor list).
// Returns RET_INFER_INVALID if inference is disabled, RET_ERROR on bad
// inputs/outputs, RET_OK otherwise.
int LoopCond::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are indexed in lockstep with inputs; guard against an undersized
  // outputs vector to avoid out-of-bounds access.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList carries extra metadata that set_shape/set_data_type do not copy.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_
#define LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Converter-only primitive wrapping the TF control-flow op `LoopCond` (marks
// the boolean loop-continuation condition of a while loop). Tagged via
// op_type_ with a ConverterPrimitiveType value since it has no
// schema::PrimitiveType entry.
class LoopCond : public PrimitiveC {
 public:
  LoopCond() { op_type_ = ConverterPrimitiveType_LoopCond; }
  ~LoopCond() = default;
  MS_DECLARE_PARENT(LoopCond, PrimitiveC);
  explicit LoopCond(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Copies data type/shape/format (and TensorList metadata) from each input to
  // the corresponding output.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/ops/next_iteration.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
// Infers output tensors for the TF control-flow op NextIteration, which feeds a
// value to the next loop iteration unchanged: output i takes the data type,
// shape and format of input i (plus TensorList metadata when the element is a
// tensor list). Returns RET_INFER_INVALID if inference is disabled, RET_ERROR
// on bad inputs/outputs, RET_OK otherwise.
int NextIteration::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are indexed in lockstep with inputs; guard against an undersized
  // outputs vector to avoid out-of-bounds access.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList carries extra metadata that set_shape/set_data_type do not copy.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Converter-only primitive wrapping the TF control-flow op `NextIteration`
// (feeds a value back to the next while-loop iteration). Tagged via op_type_
// with a ConverterPrimitiveType value since it has no schema::PrimitiveType
// entry.
class NextIteration : public PrimitiveC {
 public:
  NextIteration() { op_type_ = ConverterPrimitiveType_NextIteration; }
  ~NextIteration() = default;
  MS_DECLARE_PARENT(NextIteration, PrimitiveC);
  explicit NextIteration(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Copies data type/shape/format (and TensorList metadata) from each input to
  // the corresponding output.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
// Primitive type tags that exist only inside the converter: they extend the
// flatbuffer-generated schema::PrimitiveType enum past its last value
// (PrimitiveType_MAX) so converter-only ops (TF while-loop control flow) can be
// identified via PrimitiveC::op_type_ without modifying the schema.
enum ConverterPrimitiveType {
  ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
  ConverterPrimitiveType_LoopCond,
  ConverterPrimitiveType_NextIteration,
  ConverterPrimitiveType_Exit,
};
} // namespace lite
} // namespace mindspore
#endif  // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_enter_parser.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/enter.h"
namespace mindspore {
namespace lite {
// Parses a TF `Enter` node into a converter-only Enter primitive.
// All of the node's inputs are forwarded as-is and output_size is set to the
// input count. Returns RET_NULL_PTR on null out-parameters, RET_ERROR on
// allocation failure, RET_OK on success.
STATUS TFEnterParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF EnterParser";
  // `inputs` is dereferenced below, so it must be validated together with the
  // other out-parameters (the original checked only primitiveC/output_size).
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  *primitiveC = new (std::nothrow) Enter();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfEnterParser("Enter", new TFEnterParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Node parser that converts a TF `Enter` NodeDef into a converter-only Enter
// primitive; registered under the op name "Enter".
class TFEnterParser : public TFNodeParser {
 public:
  TFEnterParser() = default;
  ~TFEnterParser() override = default;
  // Fills primitiveC with a new Enter, copies the node's input names into
  // `inputs`, and sets `output_size` to the input count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_exit_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/exit.h"
namespace mindspore {
namespace lite {
// Parses a TF `Exit` node into a converter-only Exit primitive.
// All of the node's inputs are forwarded as-is and output_size is set to the
// input count. Returns RET_NULL_PTR on null out-parameters, RET_ERROR on
// allocation failure, RET_OK on success.
STATUS TFExitParser::Parse(const tensorflow::NodeDef &tf_op,
                           const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                           std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF ExitParser";
  // `inputs` is dereferenced below, so it must be validated together with the
  // other out-parameters (the original checked only primitiveC/output_size).
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  *primitiveC = new (std::nothrow) Exit();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfExitParser("Exit", new TFExitParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Node parser that converts a TF `Exit` NodeDef into a converter-only Exit
// primitive; registered under the op name "Exit".
class TFExitParser : public TFNodeParser {
 public:
  TFExitParser() = default;
  ~TFExitParser() override = default;
  // Fills primitiveC with a new Exit, copies the node's input names into
  // `inputs`, and sets `output_size` to the input count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_loop_cond_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/loop_cond.h"
namespace mindspore {
namespace lite {
// Parses a TF `LoopCond` node into a converter-only LoopCond primitive.
// All of the node's inputs are forwarded as-is and output_size is set to the
// input count. Returns RET_NULL_PTR on null out-parameters, RET_ERROR on
// allocation failure, RET_OK on success.
STATUS TFLoopCondParser::Parse(const tensorflow::NodeDef &tf_op,
                               const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF LoopCondParser";
  // `inputs` is dereferenced below, so it must be validated together with the
  // other out-parameters (the original checked only primitiveC/output_size).
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  *primitiveC = new (std::nothrow) LoopCond();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfLoopCondParser("LoopCond", new TFLoopCondParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Parser for the TensorFlow "LoopCond" control-flow node (the
// loop-continue predicate of a TF while loop).
class TFLoopCondParser : public TFNodeParser {
 public:
  TFLoopCondParser() = default;
  ~TFLoopCondParser() override = default;
  // Creates a LoopCond primitive for tf_op; collects the node's input names
  // into *inputs and sets *output_size to the node's input count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_merge_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
// Parses a TF "Merge" control-flow node into a lite Merge primitive built
// from the flatbuffer schema (schema::PrimitiveType_Merge).
// Outputs: *primitiveC receives the created PrimitiveC (ownership passes to
// the caller), *inputs receives the node's TF input names in order, and
// *output_size is set to the node's input count.
// Returns RET_NULL_PTR on null output parameters or allocation failure,
// RET_ERROR if PrimitiveC::Create fails, RET_OK otherwise.
STATUS TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF MergeParser";
  // inputs is dereferenced below, so it must be checked alongside the others.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::MergeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }
  primitive->value.type = schema::PrimitiveType_Merge;
  primitive->value.value = attr.release();
  // NOTE(review): assumes PrimitiveC::Create takes ownership of the released
  // PrimitiveT even on failure — confirm, otherwise this leaks on the error
  // path below.
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfMergeParser("Merge", new TFMergeParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// Parser for the TensorFlow "Merge" control-flow node, converted to a
// schema::PrimitiveType_Merge primitive.
class TFMergeParser : public TFNodeParser {
 public:
  TFMergeParser() = default;
  ~TFMergeParser() override = default;
  // Creates a Merge primitive for tf_op; collects the node's input names
  // into *inputs and sets *output_size to the node's input count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_

@ -277,7 +277,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co
}
param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t));
} else {
MS_LOG(ERROR) << "Unsupport dataType: " << type;
MS_LOG(ERROR) << "Unsupported dataType: " << type;
return RET_ERROR;
}
@ -417,6 +417,16 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
MS_LOG(ERROR) << "Convert ops failed.";
return nullptr;
}
if (!nodes_with_null_input_.empty()) {
status = ConnectNullInput();
if (status != RET_OK) {
MS_LOG(ERROR) << "Connect null inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
}
status = ConvertRootGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
@ -474,16 +484,16 @@ STATUS TFModelParser::ConvertSubgraph() {
std::vector<ParameterPtr> sub_graph_inputs;
for (int j = 0; j < input_arg_size; j++) {
auto &input_arg = tf_sub_signature.input_arg(j);
auto paramter = sub_func_graph->add_parameter();
paramter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = paramter;
auto parameter = sub_func_graph->add_parameter();
parameter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = parameter;
auto root_inputs = cnode->inputs();
if (op_type == schema::PrimitiveType_While) {
paramter->set_abstract(root_inputs[j + 1]->abstract());
parameter->set_abstract(root_inputs[j + 1]->abstract());
} else {
paramter->set_abstract(root_inputs[j + 2]->abstract());
parameter->set_abstract(root_inputs[j + 2]->abstract());
}
sub_graph_inputs.emplace_back(paramter);
sub_graph_inputs.emplace_back(parameter);
}
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
@ -643,7 +653,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs) {
std::vector<AnfNodePtr> *inputs,
std::vector<std::string> *input_name_not_found) {
MS_ASSERT(node_def != nullptr);
// parse inputs
for (size_t j = 0; j < input_names.size(); j++) {
@ -656,8 +667,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
}
auto input = GetAnfNode(flatten_input_name, anf_node_map);
if (input == nullptr) {
MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
return RET_ERROR;
MS_LOG(WARNING) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
(*input_name_not_found).push_back(flatten_input_name);
}
inputs->emplace_back(input);
}
@ -718,6 +729,27 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
return RET_OK;
}
// Remembers a CNode that still has unresolved (null) inputs, together with
// the input names that could not be found at parse time;
// ConnectNullInput() patches them in once the whole root graph is parsed.
STATUS TFModelParser::RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found) {
  nodes_with_null_input_.push_back(std::make_pair(node, input_name_not_found));
  return RET_OK;
}
// Patches CNodes recorded by RecordNullInput(): each null input slot is
// filled with the node looked up by the corresponding recorded name in the
// now-complete root node map (handles forward references such as while-loop
// back-edges, where the producer was parsed after the consumer).
// Returns RET_ERROR if a node has more null slots than recorded names, and
// RET_NULL_PTR if a recorded name is still missing from the root map.
STATUS TFModelParser::ConnectNullInput() {
  for (auto &it : nodes_with_null_input_) {
    auto &cnode = it.first;
    auto &input_name_not_found = it.second;
    auto &inputs = cnode->inputs();
    size_t name_idx = 0;  // walks the recorded names in null-slot order
    for (size_t j = 0; j < inputs.size(); ++j) {
      if (inputs[j] != nullptr) {
        continue;
      }
      if (name_idx >= input_name_not_found.size()) {
        MS_LOG(ERROR) << "more null inputs than recorded missing names for node " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
      auto anf_node = GetAnfNode(input_name_not_found[name_idx], anf_root_node_map_);
      if (anf_node == nullptr) {
        // Without this check the null input would silently be re-installed.
        MS_LOG(ERROR) << "input " << input_name_not_found[name_idx] << " still can't be found in root node map";
        return RET_NULL_PTR;
      }
      cnode->set_input(j, anf_node);
      ++name_idx;
    }
  }
  return RET_OK;
}
STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr,
@ -752,7 +784,8 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return RET_ERROR;
}
std::vector<AnfNodePtr> inputs = {value_node};
status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs);
std::vector<std::string> input_name_not_found{};
status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs, &input_name_not_found);
if (status != RET_OK) {
return status;
}
@ -787,6 +820,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
}
}
if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";

@ -22,6 +22,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "proto/graph.pb.h"
#include "proto/node_def.pb.h"
#include "schema/inner/model_generated.h"
@ -55,7 +56,7 @@ class TFModelParser : public ModelParser {
STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs);
std::vector<AnfNodePtr> *inputs, std::vector<std::string> *input_name_not_found);
STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map, const FuncGraphPtr &anf_graph,
int output_size);
@ -71,6 +72,10 @@ class TFModelParser : public ModelParser {
STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);
STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);
STATUS ConnectNullInput();
FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
@ -79,6 +84,7 @@ class TFModelParser : public ModelParser {
std::vector<std::string> graph_output_names_;
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
};
} // namespace lite
} // namespace mindspore

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_next_iteration_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/next_iteration.h"
namespace mindspore {
namespace lite {
// Parses a TF "NextIteration" node (a while-loop back-edge) into a lite
// NextIteration primitive.
// Outputs: *primitiveC receives a heap-allocated NextIteration (ownership
// passes to the caller), *inputs receives the node's TF input names in
// order, and *output_size is set to the node's input count.
// Returns RET_NULL_PTR on null output parameters, RET_ERROR on allocation
// failure, RET_OK otherwise.
STATUS TFNextIterationParser::Parse(const tensorflow::NodeDef &tf_op,
                                    const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                    PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF NextIterationParser";
  // inputs is dereferenced below, so it must be checked alongside the others.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }
  *primitiveC = new (std::nothrow) NextIteration();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  // A NextIteration node simply forwards all of its TF inputs.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfNextIterationParser("NextIteration", new TFNextIterationParser());
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save