From 80ed8e0e5c1bfce99fe0d461e52eacfacf73c06d Mon Sep 17 00:00:00 2001 From: VectorSL Date: Sat, 18 Jul 2020 17:20:17 +0800 Subject: [PATCH] fix gpu cast fusion bug --- .../optimizer/gpu/replace_bn_cast_fusion.cc | 51 ++++++++--------- .../gpu/replace_bn_grad_cast_fusion.cc | 55 +++++++++++-------- 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc index 8e90f044fa..2d48e5b002 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc @@ -30,8 +30,7 @@ const BaseRef ReplaceBNCastFusion::DefinePattern() const { VectorRef in_cast = VectorRef({prim::kPrimCast, x_}); VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_}); VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_}); - VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget}); - return out_cast; + return tupleget; } const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, @@ -40,19 +39,9 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); - auto tuple = AnfAlgo::GetInputNode(utils::cast(node), 0); - auto index_node = AnfAlgo::GetInputNode(utils::cast(tuple), 1); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - - auto fbn2 = AnfAlgo::GetInputNode(utils::cast(tuple), 0); + auto fbn2 = AnfAlgo::GetInputNode(utils::cast(node), 0); auto x_after = AnfAlgo::GetInputNode(utils::cast(fbn2), 0); auto x_before = AnfAlgo::GetInputNode(utils::cast(x_after), 0); - if (item_idx != 0) { - return nullptr; - } auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2), 1); auto bias = 
AnfAlgo::GetInputNode(utils::cast(fbn2), 2); auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2), 3); @@ -65,14 +54,32 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A MS_EXCEPTION_IF_NULL(bias); MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(var); - + std::vector outputs_type; + std::vector> outputs_shape; auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - manager->Replace(utils::cast(x_after), utils::cast(x_before)); - manager->Replace(utils::cast(node), utils::cast(tuple)); - std::vector outputs_type; - std::vector> outputs_shape; + auto outlist = GetRealNodeUsedList(graph, fbn2); + for (size_t i = 0; i < outlist->size(); i++) { + auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(i).first), 1); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx == 0) { + auto cast = GetRealNodeUsedList(graph, outlist->at(i).first); + if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { + return nullptr; + } + manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(i).first)); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get()); + } + } + + manager->Replace(utils::cast(x_after), utils::cast(x_before)); + outputs_type.clear(); + outputs_shape.clear(); auto output_num = AnfAlgo::GetOutputTensorNum(fbn2); for (size_t i = 0; i < output_num; i++) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i)); @@ -80,13 +87,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A } outputs_type[0] = kNumberTypeFloat16; AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get()); - - outputs_type.clear(); - outputs_shape.clear(); - outputs_type.push_back(kNumberTypeFloat16); - 
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0)); - AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get()); - return tuple; + return node; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc index 9dba16bf86..37bb0d96ad 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc @@ -30,8 +30,7 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const { VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_}); VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_}); VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_}); - VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget}); - return out_cast; + return tupleget; } const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, @@ -40,21 +39,16 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); - auto tuple = AnfAlgo::GetInputNode(utils::cast(node), 0); - auto index_node = AnfAlgo::GetInputNode(utils::cast(tuple), 1); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - if (item_idx != 0) { - return nullptr; - } - auto fbn2g = AnfAlgo::GetInputNode(utils::cast(tuple), 0); + auto fbn2g = AnfAlgo::GetInputNode(utils::cast(node), 0); auto dy_after = AnfAlgo::GetInputNode(utils::cast(fbn2g), 0); auto dy_before = AnfAlgo::GetInputNode(utils::cast(dy_after), 0); auto x_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 1); - + auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0); + // if x_type is fp32, the cast is necessary.
+ if (x_type == kNumberTypeFloat32) { + return nullptr; + } auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2g), 2); auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2g), 3); auto var = AnfAlgo::GetInputNode(utils::cast(fbn2g), 4); @@ -66,13 +60,32 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con MS_EXCEPTION_IF_NULL(x_); MS_EXCEPTION_IF_NULL(mean); MS_EXCEPTION_IF_NULL(var); - + std::vector outputs_type; + std::vector> outputs_shape; auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); + + auto outlist = GetRealNodeUsedList(graph, fbn2g); + for (size_t i = 0; i < outlist->size(); i++) { + auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(i).first), 1); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx == 0) { + auto cast = GetRealNodeUsedList(graph, outlist->at(i).first); + if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { + return nullptr; + } + manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(i).first)); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get()); + } + } + outputs_type.clear(); + outputs_shape.clear(); manager->Replace(utils::cast(dy_after), utils::cast(dy_before)); - manager->Replace(utils::cast(node), utils::cast(tuple)); - std::vector outputs_type; - std::vector> outputs_shape; + auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g); for (size_t i = 0; i < output_num; i++) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i)); @@ -80,12 +93,8 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con } outputs_type[0] = kNumberTypeFloat16; AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get()); - outputs_type.clear(); - outputs_shape.clear(); 
- outputs_type.push_back(kNumberTypeFloat16); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0)); - AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get()); - return tuple; + + return node; } } // namespace opt } // namespace mindspore