!4442 fix conv_action_fusion bug

Merge pull request !4442 from zhengjun10/master
pull/4442/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 8da3c54c8b

@ -48,4 +48,6 @@ scan_hms_angle1.tflite
hiai_latin_ocr.tflite
hiai_latin_ocr_1.tflite
ml_ocr_jk.tflite
nasnet_mobile.tflite
nasnet_large.tflite
inception_resnet_v2.tflite

@ -15,6 +15,8 @@
*/
#include "tools/optimizer/common/gllo_utils.h"
#include <vector>
#include <algorithm>
#include <utility>
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
@ -367,5 +369,29 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) {
return 1;
}
}
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto output_node_list = GetRealNodeUsedList(graph, node);
if (output_node_list->size() != 1) {
MS_LOG(DEBUG) << "fusion node has multi output nodes";
return true;
}
return false;
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
auto output_info_list = iter->second;
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
return output_node_list;
}
} // namespace opt
} // namespace mindspore

@ -59,6 +59,8 @@ bool IsConvNode(const BaseRef &n);
bool CheckIsAllInputsParam(const AnfNodePtr &node);
size_t GetOutputTensorNum(const AnfNodePtr &node);
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

@ -51,6 +51,9 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
AnfNodePtr pre_node = act_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
if (pre_node != nullptr && pre_node->isa<CNode>()) {
if (IsMultiOutputTensors(func_graph, pre_node)) {
return node;
}
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));

@ -138,6 +138,9 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
AnfNodePtr conv_node_anf = add_node->input(1);
CheckIfAnfNodeIsNull(conv_node_anf);
if (IsMultiOutputTensors(func_graph, conv_node_anf)) {
return add_node;
}
auto conv_node = conv_node_anf->cast<CNodePtr>();
CheckIfCNodeIsNull(conv_node);
GenConvNewBias(func_graph, conv_node, add_node);

@ -66,7 +66,9 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
auto pre_node = transform_node->input(1);
auto conv_node = pre_node->cast<CNodePtr>();
if (IsMultiOutputTensors(func_graph, conv_node)) {
return transform_node;
}
int kernel_nums = Get_Kenrnel_nums(conv_node);
if (kernel_nums <= 0) {
MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString();

Loading…
Cancel
Save