Enable BatchNorm fusion pass

pull/866/head
YuJianfeng 5 years ago
parent 4c32d7e6b8
commit 7185961e89

@ -19,6 +19,7 @@
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ir_fission/bn_split.h"
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
#include "pre_activate/pass/communication_op_fusion.h"
@ -87,7 +88,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>());
@ -193,8 +193,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
if (context_ptr->ir_fusion_flag()) {
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
@ -26,29 +27,37 @@ class FusedBatchNormFusion : public PatternProcessPass {
public:
explicit FusedBatchNormFusion(bool multigraph = true)
: PatternProcessPass("fused_batch_norm_fusion", multigraph),
data_input_var0_(std::make_shared<Var>()),
data_input_var1_(std::make_shared<Var>()),
data_input_var2_(std::make_shared<Var>()),
variable_input_var0_(std::make_shared<Var>()),
variable_input_var1_(std::make_shared<Var>()),
constant_input_var0_(std::make_shared<Var>()),
constant_input_var1_(std::make_shared<Var>()) {}
data_input0_var_(std::make_shared<Var>()),
data_input1_var_(std::make_shared<Var>()),
data_input2_var_(std::make_shared<Var>()),
variable_input0_var_(std::make_shared<Var>()),
variable_input1_var_(std::make_shared<Var>()),
constant_input0_var_(std::make_shared<Var>()),
constant_input1_var_(std::make_shared<Var>()),
batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {}
~FusedBatchNormFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
abstract::AbstractTuplePtr CreateAbstractOfFusedBatchNorm(const EquivPtr &equiv, const AnfNodePtr &bn) const;
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const;
void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
std::vector<AnfNodePtr> *bn_training_update_inputs) const;
void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
std::vector<AbstractBasePtr> *abstract_list) const;
AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
ValuePtr GetFactor(const EquivPtr &equiv) const;
VarPtr data_input_var0_;
VarPtr data_input_var1_;
VarPtr data_input_var2_;
VarPtr variable_input_var0_;
VarPtr variable_input_var1_;
VarPtr constant_input_var0_;
VarPtr constant_input_var1_;
VarPtr data_input0_var_;
VarPtr data_input1_var_;
VarPtr data_input2_var_;
VarPtr variable_input0_var_;
VarPtr variable_input1_var_;
VarPtr constant_input0_var_;
VarPtr constant_input1_var_;
VarPtr batch_norm_var_;
};
} // namespace opt
} // namespace mindspore

@ -62,6 +62,7 @@ class _BatchNorm(Cell):
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group = check_int_positive(device_num_each_group)
self.is_global = False
if self.group != 1:
self.rank_id = get_rank()
self.rank_size = get_group_size()
@ -80,15 +81,18 @@ class _BatchNorm(Cell):
self.cast = P.Cast()
self.dtype = P.DType()
self.reshape = P.Reshape()
self.is_ascend = context.get_context("device_target") == "Ascend"
if context.get_context("enable_ge"):
self.is_ge_backend = True
self.momentum = Tensor(1.0 - momentum, mstype.float32)
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
else:
self.is_ge_backend = False
self.momentum = 1.0 - momentum
if self.is_ge_backend or self.is_ascend:
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
else:
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
momentum=self.momentum)
@ -140,24 +144,23 @@ class _BatchNorm(Cell):
def construct(self, x):
if self.training and self.use_batch_statistics:
if self.is_ge_backend:
if self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
else:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
if self.is_ge_backend and self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
elif self.is_ge_backend or self.is_ascend:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y = self.bn_train(x,
self.gamma,

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
namespace mindspore {
namespace opt {
class TestHWFusedBatchNormFusion : public BackendCommon {
public:
TestHWFusedBatchNormFusion() : get_py_fun_("gtest_input.pre_activate.fused_batch_norm_fusion_test", true) {}
~TestHWFusedBatchNormFusion() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
std::vector<int> shp_y{64};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract};
for (size_t i = 0; i < 6; ++i) {
args_spec_list.push_back(y_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::FusedBatchNormFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

@ -24,7 +24,8 @@ make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
depend = Primitive('depend')
BatchNorm = P.BatchNorm()
FusedBatchNorm = P.FusedBatchNorm()
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdate = Primitive('BNTrainingUpdate')
constant0 = Tensor(0.1, mstype.float32)
constant1 = Tensor(0.1, mstype.float32)
@ -40,7 +41,7 @@ class FnDict:
return self.fnDict[name]
def useless_test_fused_batch_norm_fusion(tag):
def test_fused_batch_norm_fusion(tag):
fns = FnDict()
@fns
@ -60,9 +61,11 @@ def useless_test_fused_batch_norm_fusion(tag):
@fns
def after(input0, input1, input2, input3, input4, var0, var1):
fused_batch_norm = FusedBatchNorm(input0, input1, input2, var0, var1)
outputs = make_tuple(tuple_getitem(fused_batch_norm, 0), tuple_getitem(fused_batch_norm, 3),
tuple_getitem(fused_batch_norm, 4))
bn_training_reduce = BNTrainingReduce(input0)
bn_training_update = BNTrainingUpdate(input0, tuple_getitem(bn_training_reduce, 0),
tuple_getitem(bn_training_reduce, 1), input1, input2, var0, var1)
outputs = make_tuple(tuple_getitem(bn_training_update, 0), tuple_getitem(bn_training_update, 3),
tuple_getitem(bn_training_update, 4))
output = tuple_getitem(outputs, 0)
return make_tuple(output)

Loading…
Cancel
Save