modify batch_normal

pull/11132/head
lilei 4 years ago
parent f87b5e0cc8
commit 9a45c4419c

@ -155,7 +155,6 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
@ -270,15 +269,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else {
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
}
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());

@ -262,7 +262,8 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
const std::vector<size_t> &origin_shape, const TypeId &origin_type,
const std::vector<Axis> &reshape_type) {
MS_EXCEPTION_IF_NULL(func_graph);
std::string input_format = format;
std::string output_format = format;
@ -272,6 +273,8 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format});
builder.SetInputsReshapeType({reshape_type});
builder.SetOutputsReshapeType({reshape_type});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsDeviceType({output_type});
builder.SetFusionType(kernel::FusionType::OPAQUE);

@ -96,7 +96,8 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type);
const std::vector<size_t> &origin_shape, const TypeId &origin_type,
const std::vector<Axis> &reshape_type = std::vector<Axis>{});
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);

@ -143,7 +143,12 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
}
// insert depend
if (origin_format != cur_format || origin_type != cur_type) {
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node};
std::vector<AnfNodePtr> depend_nodes;
if (get_item.get() != nullptr) {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
} else {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
}
final_node = func_graph->NewCNode(depend_nodes);
MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString();
}

@ -58,8 +58,8 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
if (origin_type != device_type) {
replace_node =
AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type);
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
infer_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
@ -107,8 +107,8 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
AnfNodePtr replace_node = cnode;
if (origin_type != device_type) {
replace_node =
AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type);
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
infer_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);

@ -114,12 +114,16 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode
const BaseRef BnGradSplit::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimFusedBatchNormGrad, Xs});
return VectorRef({prim::kPrimBatchNormGrad, Xs});
}
const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (!GetBoolAttr(cnode, kAttrIsTraining)) {
MS_LOG(INFO) << "is training should be true if do fusion";
return nullptr;
}
return BNGradSplitForTBE(func_graph, cnode);
}
} // namespace opt

@ -96,7 +96,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
return bn_training_update;
}
AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
@ -125,11 +125,15 @@ AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNo
const BaseRef BnSplit::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
return VectorRef({prim::kPrimFusedBatchNorm, Xs});
return VectorRef({prim::kPrimBatchNorm, Xs});
}
const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
return SplitFusedBatchNormForTBE(func_graph, node);
if (!GetBoolAttr(node, kAttrIsTraining)) {
MS_LOG(INFO) << "is training should be true if do fusion";
return nullptr;
}
return SplitBatchNormForTBE(func_graph, node);
}
} // namespace opt
} // namespace mindspore

@ -483,6 +483,9 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
<< " trace: " << trace::DumpSourceLines(anf_node);
}
if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
return VisitKernelWithReturnType(anf_node, 0, visit_nop_node);
}
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
return VisitKernelWithReturnType(input_node, 0, visit_nop_node);

@ -87,8 +87,7 @@ class _BatchNorm(Cell):
self.cast = P.Cast()
self.dtype = P.DType()
self.reshape = P.Reshape()
self.is_ascend = context.get_context("device_target") == "Ascend"
self.is_gpu = context.get_context("device_target") == "GPU"
self._target = context.get_context("device_target")
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
if context.get_context("enable_ge"):
@ -96,22 +95,21 @@ class _BatchNorm(Cell):
else:
self.is_ge_backend = False
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
if self._target == "Ascend":
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
elif self.is_gpu:
epsilon=self.eps,
momentum=self.momentum)
if self._target == "GPU":
self.bn_train = P.FusedBatchNormEx(mode=1,
epsilon=self.eps,
momentum=self.momentum,
data_format=self.format)
else:
if self._target == "CPU":
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
momentum=self.momentum)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
self.enable_default_train = self.is_graph_mode and not self.is_global and \
(self.is_ge_backend or self.is_ascend)
data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
@ -168,21 +166,6 @@ class _BatchNorm(Cell):
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
return self._global_sync(x, axes, re_shape)
if self.enable_default_train:
y, batch_mean, batch_var, _, _ = self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
return y
return self.bn_train(x,
self.gamma,
self.beta,

@ -426,15 +426,14 @@ class Conv2dBnFoldQuantOneConv(Cell):
self.quant_dtype = quant_dtype
data_format = 'NCHW'
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
self.is_gpu = context.get_context('device_target') == "GPU"
self.is_ascend = context.get_context('device_target') == "Ascend"
self._target = context.get_context("device_target")
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
if context.get_context("enable_ge"):
self.is_ge_backend = True
else:
self.is_ge_backend = False
self.enable_default_train = self.is_graph_mode and \
(self.is_ge_backend or self.is_ascend)
(self.is_ge_backend or self._target == "Ascend")
# initialize convolution op and Parameter
self.conv = P.Conv2D(out_channel=out_channels,
@ -468,15 +467,16 @@ class Conv2dBnFoldQuantOneConv(Cell):
channel_axis=channel_axis,
num_channels=out_channels,
quant_dtype=quant_dtype)
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
if self._target == "Ascend":
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
elif self.is_gpu:
epsilon=self.eps,
momentum=self.momentum)
if self._target == "GPU":
self.bn_train = P.FusedBatchNormEx(mode=1,
epsilon=self.eps,
momentum=self.momentum,
data_format=self.format)
else:
if self._target == "CPU":
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
momentum=self.momentum)
@ -520,21 +520,6 @@ class Conv2dBnFoldQuantOneConv(Cell):
else:
conv_orig = conv / scale_factor
if self.training:
if self.enable_default_train:
out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
return out
return self.bn_train(conv_orig,
self.gamma,
self.beta,

@ -1058,9 +1058,10 @@ class BatchNorm(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5, data_format="NCHW"):
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
validator.check_value_type('is_training', is_training, (bool,), self.name)
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")

@ -28,7 +28,6 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from .config import cfg
from .fused_layer_norm import FusedLayerNorm
from .lr_generator import get_bert_damping
from .thor_layer import Dense_Thor, Embedding_Thor
@ -277,11 +276,7 @@ class BertOutput(nn.Cell):
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):

@ -1,127 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described using the following formula. It is applied across all channels
and pixel but only one batch size.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_nrom (bool): Whether use batchnorm to preocess.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape))
self.beta = Parameter(initializer(
beta_init, normalized_shape))
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
self.mul = P.Mul()
self.add = P.TensorAdd()
def construct(self, input_x):
"""construct of FusedLayerNorm"""
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = self.mul(output, self.gamma)
y = self.add(y, self.beta)
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s

@ -25,7 +25,6 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
@ -251,11 +250,7 @@ class BertOutput(nn.Cell):
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):

@ -1,120 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described using the following formula. It is applied across all channels
and pixel but only one batch size.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_nrom (bool): Whether use batchnorm to preocess.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s

@ -81,10 +81,12 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetInputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetOutputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetInputsDeviceType(
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
builder1.SetOutputsDeviceType({kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
builder1.SetOutputsDeviceType(
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
builder1.SetKernelType(TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get());
// do bn_grad_split pass

@ -18,7 +18,7 @@ from mindspore.ops.operations import _grad_ops as G
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
bn_grad = G.FusedBatchNormGrad()
bn_grad = G.BatchNormGrad(is_training=True)
bn_grad1 = Primitive('BNGrad1')
bn_grad2 = Primitive('BNGrad2')
bn_grad3 = Primitive('BNGrad3')

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
bn = P.FusedBatchNorm()
bn = P.BatchNorm(is_training=True)
fused_bn1 = Primitive('FusedBN1')
fused_bn2 = Primitive('FusedBN2')
fused_bn3 = Primitive('FusedBN3')

@ -32,7 +32,7 @@ from tests.dataset_mock import MindData
dev_num = 8
strategy_weight = ((dev_num, 1, 1, 1), (1, 1, 1, 1))
strategy_bn = ((dev_num, 1, 1, 1), (1,), (1,))
strategy_bn = ((dev_num, 1, 1, 1), (1,), (1,), (1,), (1,))
strategy_fc_weight_bias = ((dev_num, 1), (1, 1), (1,))

@ -37,7 +37,7 @@ dev_num = 8
strategy_no_weight = ((dev_num, 1, 1, 1),)
strategy_weight = ((dev_num, 1, 1, 1), (1, 1, 1, 1))
strategy_add = ((dev_num, 1, 1, 1), (dev_num, 1, 1, 1))
strategy_bn = ((dev_num, 1, 1, 1), (1,), (1,))
strategy_bn = ((dev_num, 1, 1, 1), (1,), (1,), (1,), (1,))
strategy_fc_weight_nobias = ((1, dev_num), (1, dev_num))
strategy_tensor_add = ((1, dev_num), (dev_num,))

Loading…
Cancel
Save