FusedBatchNormEx dynamic shape support

pull/12439/head
tom__chen 4 years ago
parent 9ab234369b
commit 03b1aeecdb

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,25 +29,7 @@ namespace kernel {
template <typename T>
class FusedBatchNormExGpuKernel : public GpuKernel {
public:
FusedBatchNormExGpuKernel()
: input_x_size_(0),
input_z_size_(0),
para_size_(0),
output_size_(0),
workspace_size_(0),
reserve_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
exp_avg_factor_(0.1),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
z_desc_(nullptr),
scale_bias_mean_var_desc_(nullptr),
activation_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
FusedBatchNormExGpuKernel() { ResetResource(); }
~FusedBatchNormExGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -142,6 +124,30 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
input_x_size_ = 0;
input_z_size_ = 0;
para_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
reserve_size_ = 0;
mode_ = CUDNN_BATCHNORM_SPATIAL;
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
epsilon_ = 10e-5;
exp_avg_factor_ = 0.1;
is_null_input_ = false;
x_desc_ = nullptr;
y_desc_ = nullptr;
z_desc_ = nullptr;
scale_bias_mean_var_desc_ = nullptr;
activation_desc_ = nullptr;
handle_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");

@@ -51,6 +51,8 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -23,6 +23,15 @@
namespace mindspore {
namespace abstract {
int64_t GetAndCheckFormat(const ValuePtr &value) {
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
}
return data_format;
}
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
@@ -235,6 +244,54 @@ AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const Pri
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(x, gamma, beta, mean, variance).
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 5);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());
ShapeVector x_shape = input_x->shape()->shape();
ShapeVector x_min_shape = input_x->shape()->min_shape();
ShapeVector x_max_shape = input_x->shape()->max_shape();
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "Input rank should be 4.";
}
auto data_format_ptr = primitive->GetAttr("format");
MS_EXCEPTION_IF_NULL(data_format_ptr);
int64_t data_format = GetAndCheckFormat(data_format_ptr);
int64_t c_axis = 1;
if (data_format == Format::NHWC) {
c_axis = 3;
}
for (size_t i = 1; i < args_spec_list.size(); ++i) {
AbstractTensorPtr arg_spec = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
MS_EXCEPTION_IF_NULL(arg_spec);
MS_EXCEPTION_IF_NULL(arg_spec->shape());
ShapeVector arg_shape = arg_spec->shape()->shape();
if (arg_shape.size() != 1) {
MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size();
}
if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) {
MS_LOG(EXCEPTION) << "Arg " << i << " shape[0] should be equal to x_shape[" << c_axis << "]=" << x_shape[c_axis]
<< ", but got " << arg_shape[0];
}
}
AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
ShapeVector gamma_shape = input_gamma->shape()->shape();
ShapeVector gamma_min_shape = input_gamma->shape()->min_shape();
ShapeVector gamma_max_shape = input_gamma->shape()->max_shape();
CheckMinMaxShape(gamma_shape, &gamma_min_shape, &gamma_max_shape);
ShapePtr output_shape_ptr = std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape);
AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape);
AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr);
AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma, output_gamma};
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
@@ -311,15 +368,6 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
}
}
int64_t GetAndCheckFormat(const ValuePtr &value) {
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
}
return data_format;
}
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

@@ -141,6 +141,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2D, {InferImplConv2D, true}},

@@ -871,11 +871,12 @@ class FusedBatchNorm(Primitive):
self.target = context.get_context("device_target")
class FusedBatchNormEx(PrimitiveWithInfer):
class FusedBatchNormEx(PrimitiveWithCheck):
r"""
FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve)
than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that
moving mean and moving variance will be computed instead of being loaded.
moving mean and moving variance will be computed instead of being loaded. FusedBatchNormEx currently only
supports 4D inputs.
Batch Normalization is widely used in convolutional networks. This operation applies
Batch Normalization over input to avoid internal covariate shift as described in the
@@ -899,7 +900,7 @@ class FusedBatchNormEx(PrimitiveWithInfer):
Default: "NCHW".
Inputs:
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
data type: float16 or float32.
- **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
data type: float32.
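For reference, the normalization the docstring describes is the standard training-mode batch normalization transform; restated here as a generic formula (textbook batch norm, not something introduced by this commit):

    y = \gamma \cdot \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta

where \mu_B and \sigma_B^2 are the per-channel mean and (biased) variance of the current batch, \gamma and \beta are the scale and bias inputs, \epsilon is the epsilon attribute, and the moving mean and moving variance outputs are running averages of \mu_B and \sigma_B^2 governed by the momentum attribute.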
@@ -970,25 +971,22 @@ class FusedBatchNormEx(PrimitiveWithInfer):
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def infer_shape(self, input_x, scale, bias, mean, variance):
def check_shape(self, input_x, scale, bias, mean, variance):
input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
validator.check_equal_int(len(input_shape_norm), 4, "x rank", self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance):
def check_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return (input_x, scale, scale, scale, scale, scale)
class InstanceNorm(PrimitiveWithInfer):
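Switching from infer_shape/infer_dtype to check_shape/check_dtype means the Python primitive now only validates its inputs; shape and dtype inference is expected to come from the backend implementation (InferImplFusedBatchNormEx registered above), which can carry min/max shapes for dynamic-shape inputs. Below is a minimal, self-contained sketch of the rank/channel rule enforced on both sides, illustrative only and not a MindSpore API (-1 stands in for a dimension that is unknown at compile time):

def check_bn_arg_shapes(x_shape, arg_shapes, data_format="NCHW"):
    """x must be rank 4; scale/bias/mean/variance must be rank-1 vectors whose
    length equals the channel dimension of x (axis 1 for NCHW, axis 3 for NHWC).
    A channel dimension of -1 (unknown) skips the length comparison."""
    if len(x_shape) != 4:
        raise ValueError("input rank should be 4, but got {}".format(len(x_shape)))
    c_axis = 3 if data_format == "NHWC" else 1
    channel = x_shape[c_axis]
    for name, shape in arg_shapes.items():
        if len(shape) != 1:
            raise ValueError("{} rank should be 1, but got {}".format(name, len(shape)))
        if channel != -1 and shape[0] != channel:
            raise ValueError("{} length {} should equal channel {}".format(name, shape[0], channel))

# A dynamically shaped input such as (-1, 2, -1, -1) still passes, as long as the
# known channel dimension agrees with the per-channel parameters:
check_bn_arg_shapes([-1, 2, -1, -1], {"scale": [2], "bias": [2], "mean": [2], "variance": [2]})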

@@ -0,0 +1,128 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.nn import Cell
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import operations as P
class NetFusedBatchNormEx(Cell):
def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
super(NetFusedBatchNormEx, self).__init__()
self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
self.moving_mean = Parameter(initializer(
mean_init, num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer(
var_init, num_features), name="variance", requires_grad=False)
self.gamma = Parameter(initializer(
gamma_init, num_features), name="gamma", requires_grad=True)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=True)
self.dynshape = inner.GpuConvertToDynamicShape()
def construct(self, x):
x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fused_bn_ex():
x = np.array([[
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
[-0.1471, 0.7706, 1.6882, 2.6059],
[0.3118, 1.6882, 2.1471, 2.1471],
[0.7706, 0.3118, 2.6059, -0.1471]],
[[0.9119, 1.8518, 1.3819, -0.0281],
[-0.0281, 0.9119, 1.3819, 1.8518],
[2.7918, 0.4419, -0.4981, 0.9119],
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
weight = np.ones(2).astype(np.float32)
bias = np.ones(2).astype(np.float32)
moving_mean = np.ones(2).astype(np.float32)
moving_var = np.ones(2).astype(np.float32)
error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
bn_net = NetFusedBatchNormEx(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
output_list = bn_net(Tensor(x))
output = output_list[0]
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)
class NetFusedBatchNormExDynamic(Cell):
def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
super(NetFusedBatchNormExDynamic, self).__init__()
self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
self.moving_mean = Parameter(initializer(
mean_init, num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer(
var_init, num_features), name="variance", requires_grad=False)
self.gamma = Parameter(initializer(
gamma_init, num_features), name="gamma", requires_grad=True)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=True)
self.dynshape = inner.GpuConvertToDynamicShape()
def construct(self, x):
x = self.dynshape(x)
x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fused_bn_ex_dynamic():
x = np.array([[
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
[-0.1471, 0.7706, 1.6882, 2.6059],
[0.3118, 1.6882, 2.1471, 2.1471],
[0.7706, 0.3118, 2.6059, -0.1471]],
[[0.9119, 1.8518, 1.3819, -0.0281],
[-0.0281, 0.9119, 1.3819, 1.8518],
[2.7918, 0.4419, -0.4981, 0.9119],
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
weight = np.ones(2).astype(np.float32)
bias = np.ones(2).astype(np.float32)
moving_mean = np.ones(2).astype(np.float32)
moving_var = np.ones(2).astype(np.float32)
error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
bn_net = NetFusedBatchNormExDynamic(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
output_list = bn_net(Tensor(x))
output = output_list[0]
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)
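For context, the expect_output constants above can be reproduced to within the 1.0e-4 tolerance by a plain NumPy reference of training-mode batch normalization over NCHW input with biased per-channel statistics; a minimal sketch (not part of the commit), using the same gamma=1, beta=1 and epsilon=1e-5 as the tests:

import numpy as np

def batchnorm_train_reference(x, gamma, beta, eps=1e-5):
    # Per-channel statistics over the N, H, W axes of an NCHW tensor,
    # using the biased (population) variance for normalization.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)

# With x, weight and bias as defined in the tests above:
# np.allclose(batchnorm_train_reference(x, weight, bias), expect_output, atol=1e-4)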