!2887 Reset to fusedbatchnorm operation in pynative mode

Merge pull request !2887 from JoyLvliang/back-to-fusedbatchnorm-operation-in-pynative-mode
pull/2887/MERGE
mindspore-ci-bot, committed by Gitee
commit b71cb28b0a

@@ -238,11 +238,16 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
-  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
+  if (context_ptr->execution_mode() == kPynativeMode) {
+    ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
+    ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
+  } else {
+    ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+    ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
+  }
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
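
Note on the hunk above: the graph-compile IR fusion pipeline now branches on the execution mode held by MsContext. In PyNative mode it registers BnSplit and BnGradSplit, which keep BatchNorm on the FusedBatchNorm-style kernels, while graph mode keeps the BatchNormGradSplit/FusedBatchNormFusion passes. A minimal usage sketch (illustration only, not part of this change) of the frontend setting that drives execution_mode():

import mindspore.context as context

# Selecting PyNative mode on Ascend is what makes execution_mode() == kPynativeMode
# in the backend, so the BnSplit/BnGradSplit branch above is the one registered.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")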
@@ -282,11 +287,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
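
Note on the hunk above: judging by its name, RunOpAscendBackendIRFusionOptimization serves the per-operator (run_op) compilation path, which is only exercised in PyNative execution, so it drops BatchNormGradSplit and the FusedBatchNorm fusion passes unconditionally and registers BnSplit instead. A rough sketch of the kind of eager, op-by-op execution that goes through this path (hypothetical example; the call chain is not shown in this diff):

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor

# In PyNative mode each forward call is dispatched operator by operator, so the
# BatchNorm below is compiled through the run_op IR fusion pipeline shown above.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()])
out = net(Tensor(np.ones([1, 3, 32, 32]).astype(np.float32)))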

@@ -84,13 +84,14 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
+        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False
-        if self.is_ge_backend or self.is_ascend:
+        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -152,7 +153,7 @@
         if self.is_ge_backend and self.is_global:
             axes, re_shape = _shape_infer(F.shape(x), self.num_features)
             y = self._global_sync(x, axes, re_shape)
-        elif self.is_ge_backend or self.is_ascend:
+        elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             if self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
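
Note on the two hunks above: the Python _BatchNorm cell now picks the P.BatchNorm primitive only when it runs in graph mode on a GE/Ascend backend; in PyNative mode it falls through to the existing else branch, i.e. the FusedBatchNorm-based training path, which matches the backend change. A condensed restatement of the selection logic (the helper function is hypothetical; the names mirror the diff):

from mindspore import context
from mindspore.ops import operations as P

def pick_bn_train_op(eps):
    # Mirrors the condition added in this diff; returns None when the caller
    # should fall back to the FusedBatchNorm-based branch instead.
    is_ascend = context.get_context("device_target") == "Ascend"
    is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
    is_ge_backend = bool(context.get_context("enable_ge"))
    if is_graph_mode and (is_ge_backend or is_ascend):
        return P.BatchNorm(is_training=True, epsilon=eps)
    return None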

@@ -157,4 +157,5 @@ def test_ascend_pynative_lenet():
         total_time = total_time + cost_time
         print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
-    assert loss_output.asnumpy() < 0.1
+    assert loss_output.asnumpy() < 0.004
+    assert loss_output.asnumpy() > 0.003
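
Note on the hunk above: the PyNative LeNet test now pins the final loss into a narrow band instead of only bounding it from above. The two assertions amount to a single range check (hypothetical rewrite, not part of the change):

loss = float(loss_output.asnumpy())
assert 0.003 < loss < 0.004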

File diff suppressed because it is too large.