!12532 add SyncBatchNorm

From: @yuchaojie
Reviewed-by: 
Signed-off-by:
pull/12532/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 85461bcdb3

@@ -276,6 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
  auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
  ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<SyncBnSplit>());
  ir_fusion_pm->AddPass(std::make_shared<SyncBnGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());

@@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include "backend/optimizer/ascend/ir_fission/bn_split.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
@@ -104,6 +105,36 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}

CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> bn_update_grad_outputs;
  CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"
                      << " trace: " << trace::DumpSourceLines(cnode);
  }
  std::vector<AnfNodePtr> allreduce_mul_outputs;
  for (size_t i = 0; i < bn_update_grad_outputs.size(); ++i) {
    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode);
    allreduce_mul_outputs.emplace_back(allreduce_mul_output);
  }
  std::vector<AnfNodePtr> bn_reduce_grad_outputs;
  CreateOutputsOfReduceGrad(func_graph, cnode, allreduce_mul_outputs, &bn_reduce_grad_outputs);
  if (bn_reduce_grad_outputs.size() != 1) {
    MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"
                      << " trace: " << trace::DumpSourceLines(cnode);
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0],
                                               allreduce_mul_outputs[0], allreduce_mul_outputs[1]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}
}  // namespace
const BaseRef BnGradSplit::DefinePattern() const {
@@ -120,5 +151,17 @@ const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfN
  }
  return BNGradSplitForTBE(func_graph, cnode);
}

const BaseRef SyncBnGradSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNormGrad, Xs});
}

const AnfNodePtr SyncBnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  return SyncBNGradSplitForTBE(func_graph, cnode);
}
}  // namespace opt
}  // namespace mindspore
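For intuition, here is a minimal numpy sketch (not MindSpore code; the per-device values are made up) of the graph SyncBnGradSplit builds: the d_scale/d_offset statistics produced by BNTrainingUpdateGrad are AllReduce-summed across the group and scaled by 1/device_num, i.e. averaged, before BNTrainingReduceGrad consumes them to produce dx.

import numpy as np

device_num = 2
# hypothetical per-device outputs of BNTrainingUpdateGrad: (d_scale, d_offset)
per_device = [(np.full(64, 0.5), np.full(64, 1.5)) for _ in range(device_num)]

# AllReduce(op="sum") followed by Mul(1/device_num) is a cross-device mean
d_scale = sum(s for s, _ in per_device) * (1.0 / device_num)
d_offset = sum(o for _, o in per_device) * (1.0 / device_num)

# the pass finally packs make_tuple(dx, d_scale, d_offset), where dx comes
# from BNTrainingReduceGrad fed with the averaged statistics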

@@ -28,6 +28,14 @@ class BnGradSplit : public PatternProcessPass {
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class SyncBnGradSplit : public PatternProcessPass {
 public:
  explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {}
  ~SyncBnGradSplit() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_

@@ -17,6 +17,8 @@
#include <vector>
#include <memory>
#include <string>
#include <limits>
#include "utils/utils.h"
#include "utils/ms_context.h"
@@ -28,6 +30,9 @@
namespace mindspore {
namespace opt {
namespace {
constexpr auto kReduceOpSum = "sum";
constexpr auto kDeviceNum = "device_num";

bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
@@ -117,8 +122,105 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
}

AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has fewer than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }
  // Create BNTrainingReduce node and get outputs of BNTrainingReduce
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
    MS_LOG(WARNING) << "Create BNTrainingReduce failed, quit split";
    return nullptr;
  }
  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "Make outputs of op BNTrainingReduce failed"
                      << " trace: " << trace::DumpSourceLines(node);
  }
  std::vector<AnfNodePtr> allreduce_mul_outputs;
  for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) {
    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode);
    allreduce_mul_outputs.emplace_back(allreduce_mul_output);
  }
  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs);
}
}  // namespace

AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);
  if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] does not have attr device_num.";
  }
  auto device_num = AnfAlgo::GetNodeAttr<int64_t>(sync_bn_cnode, kDeviceNum);
  MS_LOG(INFO) << "device_num value: " << device_num;
  float device_num_reciprocal = 1.0 / device_num;

  std::vector<int64_t> device_num_shape = {};
  auto device_num_reciprocal_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, device_num_shape);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor);
  auto data_ptr = device_num_reciprocal_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto *val = reinterpret_cast<float *>(data_ptr);
  *val = device_num_reciprocal;

  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, device_num_shape);
  auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_value);
  kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value);
  return device_num_reciprocal_value;
}

AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(allreduce_input);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);
  // create AllReduce
  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)),
                                              allreduce_input};
  auto allreduce = graph->NewCNode(allreduce_inputs);
  MS_EXCEPTION_IF_NULL(allreduce);
  allreduce->set_abstract(allreduce_input->abstract());
  allreduce->set_scope(allreduce_input->scope());
  AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
  AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);
  // use SyncBatchNorm's opid as AllReduce's fusion attr
  auto sync_bn_opname = sync_bn_cnode->fullname_with_scope();
  auto opid_pos = sync_bn_opname.rfind("-op");
  if (opid_pos == std::string::npos) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid.";
  }
  int64_t opid = std::stol(sync_bn_opname.substr(opid_pos + 3));
  // user-defined fusion ids should be greater than 1, so remap opid 0 and 1 to the top of the int64 range
  if (opid < 2) {
    opid = opid - 2 + std::numeric_limits<int64_t>::max();
  }
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce);
  // create Mul
  auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
                                        device_num_reciprocal_vnode};
  auto mul = graph->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul);
  mul->set_abstract(allreduce_input->abstract());
  mul->set_scope(allreduce_input->scope());
  return mul;
}

const BaseRef BnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
@@ -132,5 +234,14 @@ const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodeP
  }
  return SplitBatchNormForTBE(func_graph, node);
}

const BaseRef SyncBnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNorm, Xs});
}

const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  return SyncBNSplitForTBE(func_graph, node);
}
}  // namespace opt
}  // namespace mindspore
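Two details of CreateAllReduceAndMul are worth spelling out. First, the AllReduce is created with op="sum" and its result is multiplied by the scalar 1/device_num value node, which together implement a cross-device mean of the batch-norm statistics. Second, the AllReduce fusion attr is the "-opN" id parsed from the SyncBatchNorm node name, with ids 0 and 1 remapped to the top of the int64 range so they stay clear of reserved fusion values. A small Python sketch of the remap arithmetic (illustrative names, mirroring the C++ above):

INT64_MAX = 2**63 - 1

def fusion_id(fullname_with_scope):
    # e.g. "Default/SyncBatchNorm-op5" -> 5; std::stol reads the digits after "-op"
    opid = int(fullname_with_scope[fullname_with_scope.rfind("-op") + 3:])
    # mirror of: if (opid < 2) opid = opid - 2 + std::numeric_limits<int64_t>::max();
    return opid - 2 + INT64_MAX if opid < 2 else opid

assert fusion_id("Default/SyncBatchNorm-op5") == 5
assert fusion_id("Default/SyncBatchNorm-op0") == INT64_MAX - 2
assert fusion_id("Default/SyncBatchNorm-op1") == INT64_MAX - 1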

@@ -28,6 +28,19 @@ class BnSplit : public PatternProcessPass {
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class SyncBnSplit : public PatternProcessPass {
 public:
  explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {}
  ~SyncBnSplit() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode);

AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode);
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_

@@ -228,6 +228,8 @@ inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared<Primitive>("SyncBatchNorm");
inline const PrimitivePtr kPrimSyncBatchNormGrad = std::make_shared<Primitive>("SyncBatchNormGrad");
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2");
inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad");

File diff suppressed because it is too large

@@ -17,6 +17,8 @@
from .. import operations as P
from .. import composite as C
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
@@ -64,5 +66,20 @@ def bprop_pqc(self):
        dx = t(dx, (1, 0))
        dy = C.tensor_dot(dout[0], out[2], ((0, 1), (0, 1)))
        return dx, dy

    return bprop


@bprop_getters.register(inner.SyncBatchNorm)
def get_bprop_sync_batch_norm(self):
    """Grad definition for `SyncBatchNorm` operation."""
    input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num)

    def bprop(x, scale, b, mean, variance, out, dout):
        saved_mean = out[3]
        saved_variance = out[4]
        out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop
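The wiring can be read as follows: the forward returns a 5-tuple whose last two entries (out[3], out[4]) are the saved statistics consumed by the grad kernel, only dout[0] (the gradient of y) enters SyncBatchNormGrad, and the moving mean/variance inputs receive zero gradients because they are updated by the forward op rather than trained. A numpy sketch with a dummy stand-in for the TBE grad kernel, shown only to make the shapes and indices concrete:

import numpy as np

def sync_bn_grad_stub(dy, x, scale, saved_mean, saved_variance):
    # stand-in for G.SyncBatchNormGrad: returns (dx, dscale, dbias) shapes
    return np.zeros_like(x), np.zeros_like(scale), np.zeros_like(scale)

def bprop_wiring(x, scale, b, mean, variance, out, dout):
    dx, dscale, dbias = sync_bn_grad_stub(dout[0], x, scale, out[3], out[4])
    # mean/variance are maintained by the op itself, so they get zero grads
    return dx, dscale, dbias, np.zeros_like(mean), np.zeros_like(variance)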

@@ -204,6 +204,24 @@ class BatchNormGrad(PrimitiveWithInfer):
        return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)


class SyncBatchNormGrad(PrimitiveWithInfer):
    """Performs grad of SyncBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, group="group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        if not isinstance(group, str):
            raise TypeError("The group attr of SyncBatchNormGrad should be str.")
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape):
        return (x_type, scale_type, scale_type)


class BiasAddGrad(PrimitiveWithInfer):
    """Computes gradients of BiasAdd."""

@@ -630,6 +630,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck):
    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
@@ -724,3 +725,93 @@ class SequenceMask(PrimitiveWithCheck):
    def check_dtype(self, lengths_dtype, maxlen_dtype):
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)


class SyncBatchNorm(PrimitiveWithInfer):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross-device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters, which
    can be described by the following formula,

    .. math::

        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `mean`.

    Outputs:
        Tuple of 5 Tensors, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the `input_x`. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)
        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])

    def infer_shape(self, input_x, scale, bias, mean, variance):
        validator.check_equal_int(len(scale), 1, "scale rank", self.name)
        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
        validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, self.name)
        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
        validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
        return (input_x, scale, scale, scale, scale)

    def infer_dtype(self, input_x, scale, bias, mean, variance):
        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
        args = {"scale": scale, "bias": bias}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        args_moving = {"mean": mean, "variance": variance}
        validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
        return (input_x, scale, bias, input_x, input_x)
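As a sanity check on the docstring example, here is a single-device numpy sketch of the normalization formula (assuming batch statistics; the real op additionally averages the statistics across the device group): with x and all parameters set to ones, the batch mean is 1 and the batch variance is 0, so y = 0/sqrt(0 + eps) * 1 + 1 = 1, matching the all-ones output printed above.

import numpy as np

eps = 1e-5
x = np.ones((2, 2), np.float32)
scale = bias = np.ones(2, np.float32)
batch_mean = x.mean(axis=0)   # per-channel mean over the batch axis
batch_var = x.var(axis=0)     # per-channel (biased) variance
y = (x - batch_mean) / np.sqrt(batch_var + eps) * scale + bias
print(y)  # [[1. 1.] [1. 1.]]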

@@ -100,5 +100,67 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) {
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_grad_split", "after2");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWBnGradSplit, test_sync_bn_grad_split_tbe) {
  get_py_fun_.SetDoResolve(true);
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "before");
  ASSERT_TRUE(g != nullptr);
  std::vector<int64_t> shp_x{1, 64, 112, 112};
  std::vector<int64_t> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(g, args_spec_list);
  EXPECT_NE(kernel_graph, nullptr);
  // get SyncBNGrad
  CNodePtr ret = kernel_graph->get_return();
  EXPECT_NE(ret, nullptr);
  EXPECT_NE(ret->input(1), nullptr);
  EXPECT_TRUE(ret->input(1)->isa<CNode>());
  auto make_tuple1 = ret->input(1)->cast<CNodePtr>();
  EXPECT_NE(make_tuple1->input(1), nullptr);
  EXPECT_TRUE(make_tuple1->input(1)->isa<CNode>());
  auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>();
  EXPECT_NE(make_tuple2->input(1), nullptr);
  EXPECT_TRUE(make_tuple2->input(1)->isa<CNode>());
  auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>();
  EXPECT_NE(tuple_getitem->input(1), nullptr);
  EXPECT_TRUE(tuple_getitem->input(1)->isa<CNode>());
  auto bn_grad = tuple_getitem->input(1)->cast<CNodePtr>();
  // get param1
  EXPECT_NE(bn_grad->input(1), nullptr);
  auto param1 = bn_grad->input(1);
  // set kernel for param1
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2;
  builder2.SetOutputsFormat({kOpFormat_NC1HWC0});
  builder2.SetOutputsDeviceType({kNumberTypeFloat32});
  AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param1.get());
  // set kernel for SyncBNGrad
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
  builder1.SetInputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder1.SetOutputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder1.SetInputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder1.SetOutputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder1.SetKernelType(TBE_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get());
  // do sync_bn_grad_split pass
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::SyncBnGradSplit>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
}  // namespace opt
}  // namespace mindspore

@@ -86,7 +86,7 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) {
  builder.SetKernelType(KernelType::TBE_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get());
-  // do bn_grad_split_pass
+  // do bn_split_pass
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::BnSplit>();
@@ -97,5 +97,54 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) {
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_split_tbe", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWBnSplit, test_sync_bn_split_tbe) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "before");
  ASSERT_TRUE(g != nullptr);
  std::vector<int64_t> shp_x{1, 64, 112, 112};
  std::vector<int64_t> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(g, args_spec_list);
  // get kernel
  auto ret = kernel_graph->get_return();
  EXPECT_NE(ret, nullptr);
  EXPECT_TRUE(ret->inputs().size() == 2);
  auto make_tuple = ret->input(1)->cast<CNodePtr>();
  EXPECT_NE(make_tuple, nullptr);
  EXPECT_TRUE(make_tuple->inputs().size() == 2);
  auto item0 = make_tuple->input(1)->cast<CNodePtr>();
  EXPECT_NE(item0, nullptr);
  EXPECT_TRUE(item0->inputs().size() == 3);
  auto bn = item0->input(1);
  EXPECT_NE(bn, nullptr);
  EXPECT_TRUE(bn->isa<CNode>());
  // set kernel for SyncBN
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder.SetOutputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder.SetInputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder.SetOutputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder.SetKernelType(KernelType::TBE_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get());
  // do sync_bn_split_pass
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::SyncBnSplit>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
}  // namespace opt
}  // namespace mindspore

@@ -16,15 +16,21 @@
from mindspore.ops import Primitive
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import _constants as Constants
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
bn_grad = G.BatchNormGrad(is_training=True)
sync_bn_grad = G.SyncBatchNormGrad()
bn_grad1 = Primitive('BNGrad1')
bn_grad2 = Primitive('BNGrad2')
bn_grad3 = Primitive('BNGrad3')
bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')
allreduce = Primitive('AllReduce')
mul = Primitive('Mul')
mul_value = Tensor(0.5, mstype.float32)


class FnDict:
@@ -85,3 +91,36 @@ def test_bn_grad_split(tag):
        return make_tuple(output)

    return fns[tag]


def test_sync_bn_grad_split(tag):
    """ test_sync_bn_grad_split """
    fns = FnDict()

    @fns
    def before(i0, i1, i2, i3, i4):
        bn_grad_output = sync_bn_grad(i0, i1, i2, i3, i4)
        item0 = tuple_getitem(bn_grad_output, 0)
        item1 = tuple_getitem(bn_grad_output, 1)
        item2 = tuple_getitem(bn_grad_output, 2)
        output = make_tuple(item0, item1, item2)
        return output

    @fns
    def after(i0, i1, i2, i3, i4):
        bn_update_grad_output = bn_training_update_grad(i0, i1, i3, i4)
        update_output0 = tuple_getitem(bn_update_grad_output, 0)
        update_output1 = tuple_getitem(bn_update_grad_output, 1)
        allreduce_output0 = allreduce(update_output0)
        allreduce_output1 = allreduce(update_output1)
        update_item0 = mul(allreduce_output0, mul_value)
        update_item1 = mul(allreduce_output1, mul_value)
        bn_reduce_grad_output = bn_training_reduce_grad(i0, i1, update_item0, update_item1, i2, i3, i4)
        output = make_tuple(bn_reduce_grad_output, update_item0, update_item1)
        item0 = tuple_getitem(output, 0)
        item1 = tuple_getitem(output, 1)
        item2 = tuple_getitem(output, 2)
        output = make_tuple(item0, item1, item2)
        return make_tuple(output)

    return fns[tag]

@@ -15,16 +15,23 @@
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import _constants as Constants
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
bn = P.BatchNorm(is_training=True)
sync_bn = inner.SyncBatchNorm()
fused_bn1 = Primitive('FusedBN1')
fused_bn2 = Primitive('FusedBN2')
fused_bn3 = Primitive('FusedBN3')
bn_training_reduce = Primitive('BNTrainingReduce')
bn_training_update = Primitive('BNTrainingUpdate')
allreduce = Primitive('AllReduce')
mul = Primitive('Mul')
mul_value = Tensor(0.5, mstype.float32)


class FnDict:
@@ -89,3 +96,30 @@ def test_bn_split_tbe(tag):
        return make_tuple(output)

    return fns[tag]


def test_sync_bn_split_tbe(tag):
    """ test_sync_bn_split_tbe """
    fns = FnDict()

    @fns
    def before(x, scale, b, mean, variance):
        bn_output = sync_bn(x, scale, b, mean, variance)
        output = tuple_getitem(bn_output, 0)
        return output

    @fns
    def after(x, scale, b, mean, variance):
        bn_training_reduce_output = bn_training_reduce(x)
        bn_training_reduce_output0 = tuple_getitem(bn_training_reduce_output, 0)
        bn_training_reduce_output1 = tuple_getitem(bn_training_reduce_output, 1)
        allreduce_output0 = allreduce(bn_training_reduce_output0)
        allreduce_output1 = allreduce(bn_training_reduce_output1)
        bn_training_update_input1 = mul(allreduce_output0, mul_value)
        bn_training_update_input2 = mul(allreduce_output1, mul_value)
        bn_training_update_output = bn_training_update(x, bn_training_update_input1, bn_training_update_input2,
                                                       scale, b, mean, variance)
        output = tuple_getitem(bn_training_update_output, 0)
        return make_tuple(output)

    return fns[tag]
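Note the mul_value constant of 0.5 used in these expected graphs: it is exactly 1/device_num for the default device_num=2 carried by inner.SyncBatchNorm() and G.SyncBatchNormGrad(), i.e. the scalar that CreateValueNodeOfDeviceNumReciprocal folds into the graph. A quick illustrative check:

device_num = 2  # default of inner.SyncBatchNorm() / G.SyncBatchNormGrad()
mul_value = 1.0 / device_num
assert mul_value == 0.5  # the Tensor(0.5, mstype.float32) in the graphs above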

@@ -1755,6 +1755,16 @@ test_case_nn_ops = [
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': ['backward']}),
    ('SyncBatchNorm', {
        'block': inner.SyncBatchNorm(),
        'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': []}),
    ('SyncBatchNormGrad', {
        'block': G.SyncBatchNormGrad(),
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': ['backward']}),
    ('TopK', {
        'block': P.TopK(),
        'desc_const': [5],
