From 2291b7f2e66ec6b002fdd32c06c5245d2abb744b Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Fri, 13 Nov 2020 16:23:44 +0800
Subject: [PATCH] dynamic shape check

---
 .../backend/kernel_compiler/gpu/gpu_kernel.h |  1 -
 .../gpu/batch_norm_add_relu_grad_fusion.cc   | 39 ++++++++++++-------
 .../ccsrc/backend/session/session_basic.cc   |  2 +-
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
index 54ad8b99ea..863829e55b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
@@ -251,7 +251,6 @@ class GpuKernel : public KernelMod {
 
   device::DynamicKernelPtr dynamic_kernel_;
 };
-
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
index ae9251c543..d80f865260 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
@@ -110,44 +110,57 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
   manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
   return;
 }
-}  // namespace
 
-const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
-  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
-  VectorRef batch_norm_grad =
-    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
-  return batch_norm_grad;
-}
-
-const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                     const EquivPtr &) const {
+bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
   MS_EXCEPTION_IF_NULL(format_attr);
   auto format = GetValue<std::string>(format_attr);
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
-    return nullptr;
+    return false;
   }
 
   auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
   MS_EXCEPTION_IF_NULL(relu_grad);
   auto relu_users = GetRealNodeUsedList(graph, relu_grad);
   if (relu_users->size() != 2) {
-    return nullptr;
+    return false;
   }
 
   // process pattern as Relu(TensorAdd(BN#0, BN#1))
   auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
   MS_EXCEPTION_IF_NULL(tuple_getitem);
   if (!utils::isa<CNodePtr>(tuple_getitem) ||
       AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
-    return nullptr;
+    return false;
   }
   auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
   if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
+    return false;
+  }
+
+  return true;
+}
+}  // namespace
+
+const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
+  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
+  VectorRef batch_norm_grad =
+    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
+  return batch_norm_grad;
+}
+
+const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  if (!PatternCheck(graph, node)) {
     return nullptr;
   }
+  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
+  MS_EXCEPTION_IF_NULL(relu_grad);
   auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
   MS_EXCEPTION_IF_NULL(dy);
   auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index d507bf56e2..a3f2880355 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1432,7 +1432,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
 }
 
 bool IsShapeDynamic(const abstract::ShapePtr &shape) {
-  return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
+  return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
 }
 
 bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
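
Note on the batch_norm_add_relu_grad_fusion.cc change: the pattern preconditions are hoisted out of Process() into a file-local PatternCheck() helper, so each failed check returns false once and Process() keeps a single early exit. A minimal standalone sketch of that guard-extraction shape follows; Node, NodePtr, and the simplified Process/PatternCheck bodies are stand-ins for illustration, not MindSpore APIs.

  #include <iostream>
  #include <memory>

  struct Node {};
  using NodePtr = std::shared_ptr<Node>;

  namespace {
  // All pattern preconditions live in one place; each failed check is a
  // single `return false`, mirroring the PatternCheck() added by the patch.
  bool PatternCheck(const NodePtr &node) {
    if (node == nullptr) {
      return false;  // stands in for the format/user-count/forward-node checks
    }
    return true;
  }
  }  // namespace

  // The transform body keeps exactly one early exit instead of three.
  NodePtr Process(const NodePtr &node) {
    if (!PatternCheck(node)) {
      return nullptr;  // the pass declines to rewrite this node
    }
    // ... build and return the fused grad node here ...
    return node;
  }

  int main() {
    std::cout << (Process(nullptr) == nullptr) << '\n';                  // 1: rejected
    std::cout << (Process(std::make_shared<Node>()) != nullptr) << '\n'; // 1: accepted
  }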
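Note on the session_basic.cc change: shape entries are int64_t, and the old lambda narrowed each dimension to int before the `s < 0` test, so on a typical two's-complement target an oversized static dimension could in principle wrap negative and be misreported as dynamic. The standalone sketch below demonstrates the failure mode; ShapeVector is a stand-in for shape->shape()'s element type, and the values are illustrative only.

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Stand-in for the shape container (a std::vector<int64_t>).
  using ShapeVector = std::vector<int64_t>;

  // Old behaviour: the lambda narrows each int64_t dimension to int, so a
  // huge static dimension can wrap negative and look dynamic.
  bool IsShapeDynamicOld(const ShapeVector &shape) {
    return std::any_of(shape.begin(), shape.end(), [](int s) { return s < 0; });
  }

  // Patched behaviour: the parameter type matches the element type, so only
  // the genuine -1 dynamic-dimension sentinel tests negative.
  bool IsShapeDynamicNew(const ShapeVector &shape) {
    return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
  }

  int main() {
    ShapeVector huge_static = {1, 4294967295LL};  // 2^32 - 1 narrows to -1 as int
    ShapeVector dynamic = {-1, 224};

    std::cout << IsShapeDynamicOld(huge_static)  // 1: false positive
              << IsShapeDynamicNew(huge_static)  // 0: correct
              << IsShapeDynamicNew(dynamic)      // 1: correct
              << '\n';
  }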