From eae5f282566d51d3bf2c3ae52fcc84c94d8b722c Mon Sep 17 00:00:00 2001
From: Hoai Linh Tran
Date: Thu, 30 Jul 2020 01:25:43 -0400
Subject: [PATCH] Remove redundant Min/Max ops for Bert

Update threshold for rounding when checking expected value in input tensor node
---
Review notes (this section is not part of the commit message):
- IsNodeScalarMaxFLT / IsNodeScalarMinFLT now require DataSize() == 1 before
  reading data[0]. The earlier `> 1` test let a zero-element tensor through,
  which does not match the stated intent ("Tensor [] or Tensor [1]").
- This copy was rebuilt from a flattened dump in which line breaks, blank
  lines and every <...> span (template arguments, author address) were lost.
  Restored template arguments, context-line whitespace and blank-line
  placement are best guesses: verify against the tree before applying. The
  `index` blob ids below are the ones from the dump and were not recomputed.

 mindspore/ccsrc/frontend/optimizer/irpass.cc  |  4 +-
 .../optimizer/irpass/value_based_eliminate.cc | 80 +++++++++++++++++++
 .../optimizer/irpass/value_based_eliminate.h  |  2 +
 mindspore/core/ir/pattern_matcher.h           |  4 +-
 4 files changed, 86 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index 75b54a6905..701bf950ca 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -168,8 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
 
   // Value_Based Eliminate
-  value_based_eliminate_ =
-    MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect});
+  value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
+                                            {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
index 365859ab4f..38b59afe96 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
@@ -19,6 +19,9 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
+#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
+#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
+
 bool IsCNodePositive(const AnfNodePtr &node) {
   if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
     return IsCNodePositive(node->cast<CNodePtr>()->input(1));
@@ -29,17 +32,94 @@ bool IsCNodePositive(const AnfNodePtr &node) {
   return false;
 }
 
+// check if a value is bigger than UPPER_FLT_LIMIT
+bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
+  auto value_node = node->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return false;
+  }
+
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return false;
+  }
+
+  auto scalar = value->cast<ScalarPtr>();
+  if (scalar != nullptr) {
+    if (scalar->isa<FloatImm>()) {
+      return GetValue<float>(scalar) > UPPER_FLT_LIMIT;
+    }
+  }
+  // Check for Tensor [] or Tensor [1]
+  auto tensor_ptr = value->cast<tensor::TensorPtr>();
+  if (tensor_ptr == nullptr) {
+    return false;
+  }
+  if (tensor_ptr->DataSize() != 1) {  // also rejects empty tensors, since data[0] is read below
+    return false;
+  }
+
+  TypeId tensor_type = tensor_ptr->Dtype()->type_id();
+  if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
+    float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
+    return data[0] > UPPER_FLT_LIMIT;
+  }
+
+  return false;
+}
+
+// check if a value is smaller than LOWER_FLT_LIMIT
+bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
+  auto value_node = node->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return false;
+  }
+
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return false;
+  }
+
+  auto scalar = value->cast<ScalarPtr>();
+  if (scalar != nullptr) {
+    if (scalar->isa<FloatImm>()) {
+      return GetValue<float>(scalar) < LOWER_FLT_LIMIT;
+    }
+  }
+  // Check for Tensor [] or Tensor [1]
+  auto tensor_ptr = value->cast<tensor::TensorPtr>();
+  if (tensor_ptr == nullptr) {
+    return false;
+  }
+  if (tensor_ptr->DataSize() != 1) {  // also rejects empty tensors, since data[0] is read below
+    return false;
+  }
+
+  TypeId tensor_type = tensor_ptr->Dtype()->type_id();
+  if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
+    float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
+    return data[0] < LOWER_FLT_LIMIT;
+  }
+
+  return false;
+}
+
 AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
   PatternNode<AnfNodePtr> x, y, z;
   PConstant<AnfNodePtr> zero_(node, false, 0);
   PConstant<AnfNodePtr> zero_scalar_(node, false, 0, true);
 
+  // {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
   MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y,
                    IsCNodePositive(x.GetNode(node)));
 
   MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
                    IsCNodePositive(x.GetNode(node)));
 
+  MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node)));
+
+  MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node)));
+
   return nullptr;
 }
 
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
index eca5fb4dbd..3ae2d90a5c 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
@@ -32,6 +32,8 @@ namespace opt {
 namespace irpass {
 
 // {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
+// {prim::kPrimMaximum, X, Y} -> X when Y is smaller than LOWER_FLT_LIMIT
+// {prim::kPrimMinimum, X, Y} -> X when Y is greater than UPPER_FLT_LIMIT
 class ValueBasedEliminate : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h
index 4a7f3c61fd..7c1a856df6 100644
--- a/mindspore/core/ir/pattern_matcher.h
+++ b/mindspore/core/ir/pattern_matcher.h
@@ -487,7 +487,7 @@ class PConstant : public PBase<PConstant<T> > {
       TypeId tensor_type = tensor_ptr->Dtype()->type_id();
       if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
         float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
-        auto threshold = FLT_EPSILON * FLT_EPSILON;
+        auto threshold = FLT_MIN;
         for (int i = 0; i < tensor_ptr->DataSize(); i++) {
           if (fabs(data2[i] - check_value_) > threshold) {
             return false;
@@ -496,7 +496,7 @@ class PConstant : public PBase<PConstant<T> > {
         return true;
       } else if (tensor_type == TypeId::kNumberTypeFloat64) {
         double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
-        auto threshold = DBL_EPSILON * DBL_EPSILON;
+        auto threshold = DBL_MIN;
         for (int i = 0; i < tensor_ptr->DataSize(); i++) {
           if (fabs(data2[i] - check_value_) > threshold) {
             return false;