!13309 add functionalize_cond & tf_bidirection_gru_cf_fusion

From: @wangzhe128
Reviewed-by: 
Signed-off-by:
pull/13309/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit bc7db669cb

@ -240,13 +240,15 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
@ -258,6 +260,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc
${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc
)

@ -61,3 +61,4 @@ ml_noya_tts_melgan.pb 1;16,16,80
ml_video_edit_oneclick_adaptis.pb 3
# Q_hand_0812.pb is not suitable for float16. Out of float16 range.
Q_hand_0812.pb
tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62

@ -50,13 +50,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_conv_fusion.cc
../optimizer/fusion/tflite_lstm_cell_fusion.cc
../optimizer/fusion/tf_lstm_cell_fusion.cc
../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/group_depthwise_op_convert_pass.cc
../optimizer/graph/tflite_inputs_adjust_pass.cc
../optimizer/graph/update_conv2d_param_pass.cc
../optimizer/graph/unused_node_remove_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/redundant_op_remove_pass.cc
@ -68,6 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/if_pass.cc
../optimizer/graph/functionalize_control_op_pass.cc
../optimizer/graph/functionalize_while.cc
../optimizer/graph/functionalize_cond.cc
../optimizer/graph/inputs_adjust_pass.cc
../optimizer/graph/primitive_adjust_pass.cc
)

@ -18,6 +18,8 @@
#include <memory>
#include <string>
#include "src/common/log_adapter.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/primitive.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
@ -31,7 +33,8 @@
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
@ -42,6 +45,7 @@
#include "tools/optimizer/graph/tflite_inputs_adjust_pass.h"
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_node_remove_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
@ -81,7 +85,7 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
}
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
@ -225,6 +229,23 @@ int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter
return RET_OK;
}
int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config) {
MS_ASSERT(old_graph != nullptr);
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
// remove remaining cyclic nodes
asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
asylic_optimizer->AddPassManager(asylic_pm);
if (!asylic_optimizer->Optimize(old_graph)) {
MS_LOG(ERROR) << "gru cf fusion pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
const FuncGraphPtr &new_graph) {
// quant
@ -266,7 +287,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return old_graph;
}
auto status = RunAdjustPass(old_graph, config);
auto status = RunPrecedingPass(old_graph, *config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Preceding pass failed.";
return nullptr;
}
status = RunAdjustPass(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Adjust pass failed.";
return nullptr;

@ -50,6 +50,8 @@ class AnfTransform {
static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
constexpr auto kNameIf = "If";
class If : public PrimitiveC {
public:
If() : PrimitiveC(kNameIf) {}
~If() = default;
MS_DECLARE_PARENT(If, PrimitiveC);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_

@ -1,39 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
constexpr auto kNameLoopCond = "LoopCond";
class LoopCond : public PrimitiveC {
public:
LoopCond() : PrimitiveC(kNameLoopCond) {}
~LoopCond() = default;
MS_DECLARE_PARENT(LoopCond, PrimitiveC);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_

@ -17,16 +17,31 @@
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
#define ADD_CONVERTER_ONLY_OP(name) \
constexpr auto kName##name = #name; \
class name : public PrimitiveC { \
public: \
name() : PrimitiveC(kName##name) {} \
~name() = default; \
MS_DECLARE_PARENT(name, PrimitiveC); \
};
enum ConverterPrimitiveType {
ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
ConverterPrimitiveType_LoopCond,
ConverterPrimitiveType_NextIteration,
ConverterPrimitiveType_Exit,
};
ADD_CONVERTER_ONLY_OP(Enter);
ADD_CONVERTER_ONLY_OP(Exit);
ADD_CONVERTER_ONLY_OP(If);
ADD_CONVERTER_ONLY_OP(LoopCond);
ADD_CONVERTER_ONLY_OP(NextIteration);
ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3);
ADD_CONVERTER_ONLY_OP(TensorArrayReadV3);
ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3);
ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
} // namespace lite
} // namespace mindspore

@ -17,7 +17,7 @@
#include "tools/converter/parser/onnx/onnx_if_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/converter/ops/if.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -19,7 +19,7 @@
#include <vector>
#include "tools/converter/parser/tf/tf_enter_parser.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/enter.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/exit.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -19,7 +19,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/if.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/loop_cond.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -28,7 +28,7 @@ ops::PrimitiveC *TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::Merge>();
*output_size = tf_op.input_size();
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/next_iteration.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

@ -28,7 +28,7 @@ ops::PrimitiveC *TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::Switch>();
*output_size = tf_op.input_size();
*output_size = 2;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_gather_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayGatherParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayGatherParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayGatherV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayGatherParser("TensorArrayGatherV3", new TFTensorArrayGatherParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFTensorArrayGatherParser : public TFNodeParser {
public:
TFTensorArrayGatherParser() = default;
~TFTensorArrayGatherParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 2;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayParser("TensorArrayV3", new TFTensorArrayParser());
} // namespace lite
} // namespace mindspore

@ -13,27 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
constexpr auto kNameEnter = "Enter";
class Enter : public PrimitiveC {
class TFTensorArrayParser : public TFNodeParser {
public:
Enter() : PrimitiveC(kNameEnter) {}
~Enter() = default;
MS_DECLARE_PARENT(Enter, PrimitiveC);
TFTensorArrayParser() = default;
~TFTensorArrayParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_read_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayReadParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayReadParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayReadV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayReadParser("TensorArrayReadV3", new TFTensorArrayReadParser());
} // namespace lite
} // namespace mindspore

@ -13,27 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
constexpr auto kNameExit = "Exit";
class Exit : public PrimitiveC {
class TFTensorArrayReadParser : public TFNodeParser {
public:
Exit() : PrimitiveC(kNameExit) {}
~Exit() = default;
MS_DECLARE_PARENT(Exit, PrimitiveC);
TFTensorArrayReadParser() = default;
~TFTensorArrayReadParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_scatter_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayScatterParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayScatterParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayScatterV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayScatterParser("TensorArrayScatterV3", new TFTensorArrayScatterParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFTensorArrayScatterParser : public TFNodeParser {
public:
TFTensorArrayScatterParser() = default;
~TFTensorArrayScatterParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_size_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArraySizeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArraySizeParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArraySizeV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArraySizeParser("TensorArraySizeV3", new TFTensorArraySizeParser());
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save