Add optimization pass to remove redundant Select, fix uninitialized parameter in test_bert_train.py script

Create a new pass named ValueBasedEliminate to reduce the load on the Arithmetic Simplify pass
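For illustration, the kind of construct this pass targets can be sketched in MindSpore Python. This is only a hedged sketch (the cell name and tensor shapes are made up, GRAPH_MODE is assumed, and whether the fold actually fires depends on which optimizer passes run): it builds Select(Greater(Square(x), 0), y, z), which the new substitution reduces to y because Square(x) is treated as positive.

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


class RedundantSelect(nn.Cell):
    """Builds Select(Greater(Square(x), 0), y, z); the condition is treated
    as always true by the new pass, so the expression can fold to y."""

    def __init__(self):
        super(RedundantSelect, self).__init__()
        self.square = P.Square()
        self.greater = P.Greater()
        self.select = P.Select()
        self.zero = Tensor(np.zeros([2]), mstype.float32)

    def construct(self, x, y, z):
        # Greater(Square(x), 0) is the redundant condition the pass matches.
        cond = self.greater(self.square(x), self.zero)
        return self.select(cond, y, z)

With value_based_eliminate enabled, the Square/Greater/Select chain is dropped during optimization and the cell effectively returns y unchanged.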

Code review
pull/3270/head
Hoai Linh Tran 5 years ago
parent 16079e6356
commit 2861e5462d

@@ -41,6 +41,7 @@
 #include "frontend/optimizer/irpass/symbol_resolver.h"
 #include "frontend/optimizer/irpass/tile_eliminate.h"
 #include "frontend/optimizer/irpass/transpose_eliminate.h"
+#include "frontend/optimizer/irpass/value_based_eliminate.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
@@ -165,6 +166,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   sparse_tensor_eliminate_ = MakeSubstitution(
     std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
     {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
+
+  // Value_Based Eliminate
+  value_based_eliminate_ =
+    MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect});
 }

 ResolveIRPassLib::ResolveIRPassLib() {

@@ -110,6 +110,9 @@ class OptimizeIRPassLib {
   // SparseTensor Eliminate
   SubstitutionPtr sparse_tensor_eliminate_;
+
+  // Value_Based Eliminate
+  SubstitutionPtr value_based_eliminate_;
 };

 // the collection of irpass for resolve action

@@ -0,0 +1,48 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/value_based_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// A CNode is considered positive if it is Square or Sqrt, looking through
// ReduceSum and Squeeze to their first input.
bool IsCNodePositive(const AnfNodePtr &node) {
  if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
    return IsCNodePositive(node->cast<CNodePtr>()->input(1));
  }
  if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) {
    return true;
  }
  return false;
}

AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  PatternNode<AnfNodePtr> x, y, z;
  PConstant zero_(node, false, 0);
  PConstant zero_scalar_(node, false, 0, true);

  MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y,
                   IsCNodePositive(x.GetNode(node)));
  MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
                   IsCNodePositive(x.GetNode(node)));

  return nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
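IsCNodePositive looks through ReduceSum and Squeeze to their first input and treats Square and Sqrt outputs as positive, and the pattern matches the zero on the right-hand side either as a zero tensor (zero_) or a zero scalar (zero_scalar_). A hypothetical Python sketch of a graph that exercises the recursive case (names invented; scalar tensors keep the Select shapes consistent):

import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P


class ReducedPositiveSelect(nn.Cell):
    """Greater(ReduceSum(Square(x)), 0): IsCNodePositive recurses through
    ReduceSum down to Square, so the Select again folds to its first branch."""

    def __init__(self):
        super(ReducedPositiveSelect, self).__init__()
        self.square = P.Square()
        self.reduce_sum = P.ReduceSum()  # reduces all axes to a scalar tensor
        self.greater = P.Greater()
        self.select = P.Select()
        self.zero = Tensor(0.0, mstype.float32)

    def construct(self, x, y, z):
        # y and z are expected to be scalar tensors so that the condition,
        # y and z all share the shape Select requires.
        cond = self.greater(self.reduce_sum(self.square(x)), self.zero)
        return self.select(cond, y, z)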

@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_

#include <algorithm>
#include <memory>
#include <vector>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/pattern_matcher.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z} -> Y when X is always greater than 0
class ValueBasedEliminate : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_

@@ -162,15 +162,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
 }

 OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig b_1 = opt::OptPassConfig({
-    irpass.zero_like_fill_zero_,
-    irpass.item_tuple_eliminate_,
-    irpass.float_tuple_getitem_switch_,
-    irpass.reset_defer_inline_,
-    irpass.inline_,
-    irpass.special_op_eliminate_,
-    irpass.get_make_ref_eliminate_,
-  });
+  opt::OptPassConfig b_1 =
+    opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
+                        irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
+                        irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
   opt::OptPassConfig b_2 = opt::OptPassConfig({
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
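One hedged way to check that the new substitution fires in pass group b_1 is to dump the intermediate IR and look for the Select/Greater nodes disappearing after the opt_b passes; this relies on the standard save_graphs context option, and the dump path below is only an example:

import mindspore.context as context

# Save the intermediate graphs produced by the frontend optimizer so the
# effect of value_based_eliminate (registered in b_1 above) can be inspected.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ir_dump")

After running a cell such as the sketches above, the dumped .ir files show whether the Select node survived the b_1 group.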

@@ -22,14 +22,15 @@ import os
 import mindspore.common.dtype as mstype
 import mindspore.context as context
 from mindspore import Tensor
+from mindspore.ops import operations as P
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.nn import learning_rate_schedule as lr_schedules
-from mindspore.ops import operations as P
 from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from ...dataset_mock import MindData
 from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph

 _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data"
 context.set_context(mode=context.GRAPH_MODE)
