From b16686964c2ea754bb3c21217efc0cc118402278 Mon Sep 17 00:00:00 2001
From: lvliang
Date: Mon, 6 Jul 2020 21:11:55 +0800
Subject: [PATCH] back-to-fusedbatchnorm-operation-in-pynative-mode

---
 .../ascend/ascend_backend_optimization.cc     |  20 +-
 mindspore/nn/layer/normalization.py           |   5 +-
 ...ascend_lenet.py => test_pynative_lenet.py} |   3 +-
 tests/st/pynative/test_pynative_resnet50.py   | 432 ++++++++++++++++++
 4 files changed, 448 insertions(+), 12 deletions(-)
 rename tests/st/pynative/{test_ascend_lenet.py => test_pynative_lenet.py} (98%)
 create mode 100644 tests/st/pynative/test_pynative_resnet50.py

diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index ff864401b1..981e2255f3 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -238,11 +238,16 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
+  if (context_ptr->execution_mode() == kPynativeMode) {
+    ir_fusion_pm->AddPass(std::make_shared());
+    ir_fusion_pm->AddPass(std::make_shared());
+  } else {
+    ir_fusion_pm->AddPass(std::make_shared());
+    ir_fusion_pm->AddPass(std::make_shared());
+    ir_fusion_pm->AddPass(std::make_shared());
+    ir_fusion_pm->AddPass(std::make_shared());
+    ir_fusion_pm->AddPass(std::make_shared());
+  }
   ir_fusion_pm->AddPass(std::make_shared());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
@@ -282,11 +287,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index d6c920b620..4c7ea9d4d6 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -84,13 +84,14 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
+        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False

-        if self.is_ge_backend or self.is_ascend:
+        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -152,7 +153,7 @@ class _BatchNorm(Cell):
         if self.is_ge_backend and self.is_global:
             axes, re_shape = _shape_infer(F.shape(x), self.num_features)
             y = self._global_sync(x, axes, re_shape)
-        elif self.is_ge_backend or self.is_ascend:
+        elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            if self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
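
Together, the two hunks above key batch-norm handling off the execution mode: the Ascend backend installs different BN-related IR fusion passes in pynative mode, and `_BatchNorm` only keeps `P.BatchNorm` when running in graph mode. Below is a minimal sketch of the Python-side selection after this patch, assuming the MindSpore 0.x context API; the `P.FusedBatchNorm` fallback is inferred from the surrounding `_BatchNorm` class (the else branch is not shown in the hunk), so treat it as illustrative rather than a copy of the source.

```python
# Sketch of how _BatchNorm picks its training primitive after this patch.
# Assumes MindSpore 0.x; the FusedBatchNorm branch is inferred, not shown
# in the diff above.
from mindspore import context
from mindspore.ops import operations as P

def select_bn_train(eps=1e-5, momentum=0.9):
    is_ascend = context.get_context("device_target") == "Ascend"
    is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
    is_ge_backend = bool(context.get_context("enable_ge"))
    if is_graph_mode and (is_ge_backend or is_ascend):
        # Graph mode on Ascend/GE keeps BatchNorm; the backend IR fusion
        # passes then split it into the BN training kernels.
        return P.BatchNorm(is_training=True, epsilon=eps)
    # Pynative mode goes back to the fused operation, hence the patch title.
    return P.FusedBatchNorm(mode=1, epsilon=eps, momentum=momentum)
```
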
diff --git a/tests/st/pynative/test_ascend_lenet.py b/tests/st/pynative/test_pynative_lenet.py
similarity index 98%
rename from tests/st/pynative/test_ascend_lenet.py
rename to tests/st/pynative/test_pynative_lenet.py
index 021c71d9cd..c6166d0517 100644
--- a/tests/st/pynative/test_ascend_lenet.py
+++ b/tests/st/pynative/test_pynative_lenet.py
@@ -157,4 +157,5 @@ def test_ascend_pynative_lenet():
         total_time = total_time + cost_time
         print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
-    assert loss_output.asnumpy() < 0.1
+    assert loss_output.asnumpy() < 0.004
+    assert loss_output.asnumpy() > 0.003
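
Both st tests drive training the same way: forward through `WithLossCell`, backward with `composite.grad_by_list` over the trainable weights, then a `Momentum` update, exactly the triple visible in the ResNet-50 test below. The following self-contained sketch isolates that step on a toy network; the `nn.Dense` stand-in, the random data, and the loss constructor are illustrative (loss-constructor signatures vary across MindSpore versions), not taken from the patch.

```python
# Minimal sketch of the pynative train step these tests share.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple, context
from mindspore.ops import composite as CP
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell

context.set_context(mode=context.PYNATIVE_MODE)

net = nn.Dense(10, 2)  # stand-in for LeNet/ResNet
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)  # signature varies by version
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
net_with_criterion = WithLossCell(net, criterion)
weights = ParameterTuple(net.trainable_params())

data = Tensor(np.random.randn(32, 10).astype(np.float32))
label = Tensor(np.random.randint(0, 2, (32,)).astype(np.int32))

loss = net_with_criterion(data, label)                             # forward
grads = CP.grad_by_list(net_with_criterion, weights)(data, label)  # backward
optimizer(grads)                                                   # update
```
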
diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py
new file mode 100644
index 0000000000..21d761dfcc
--- /dev/null
+++ b/tests/st/pynative/test_pynative_resnet50.py
@@ -0,0 +1,432 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import time
+import random
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.transforms.vision.c_transforms as vision
+import mindspore.nn as nn
+import mindspore.ops.functional as F
+
+from mindspore import Tensor
+from mindspore import context
+from mindspore import ParameterTuple
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+from mindspore.ops import composite as CP
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.common.initializer import initializer
+from mindspore.nn.wrap.cell_wrapper import WithLossCell
+
+random.seed(1)
+np.random.seed(1)
+ds.config.set_seed(1)
+
+
+def weight_variable(shape):
+    return initializer('XavierUniform', shape=shape, dtype=mstype.float32)
+
+
+def weight_variable_uniform(shape):
+    return initializer('Uniform', shape=shape, dtype=mstype.float32)
+
+
+def weight_variable_0(shape):
+    zeros = np.zeros(shape).astype(np.float32)
+    return Tensor(zeros)
+
+
+def weight_variable_1(shape):
+    ones = np.ones(shape).astype(np.float32)
+    return Tensor(ones)
+
+
+def conv3x3(in_channels, out_channels, stride=1, padding=0):
+    """3x3 convolution"""
+    weight_shape = (out_channels, in_channels, 3, 3)
+    weight = weight_variable(weight_shape)
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+
+
+def conv1x1(in_channels, out_channels, stride=1, padding=0):
+    """1x1 convolution"""
+    weight_shape = (out_channels, in_channels, 1, 1)
+    weight = weight_variable(weight_shape)
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+
+
+def conv7x7(in_channels, out_channels, stride=1, padding=0):
+    """7x7 convolution"""
+    weight_shape = (out_channels, in_channels, 7, 7)
+    weight = weight_variable(weight_shape)
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+
+
+def bn_with_initialize(out_channels):
+    shape = (out_channels)
+    mean = weight_variable_0(shape)
+    var = weight_variable_1(shape)
+    beta = weight_variable_0(shape)
+    gamma = weight_variable_uniform(shape)
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
+    return bn
+
+
+def bn_with_initialize_last(out_channels):
+    shape = (out_channels)
+    mean = weight_variable_0(shape)
+    var = weight_variable_1(shape)
+    beta = weight_variable_0(shape)
+    gamma = weight_variable_uniform(shape)
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
+    return bn
+
+
+def fc_with_initialize(input_channels, out_channels):
+    weight_shape = (out_channels, input_channels)
+    weight = weight_variable(weight_shape)
+    bias_shape = (out_channels)
+    bias = weight_variable_uniform(bias_shape)
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+class ResidualBlock(nn.Cell):
+    expansion = 4
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1):
+        super(ResidualBlock, self).__init__()
+
+        out_chls = out_channels // self.expansion
+        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
+        self.bn1 = bn_with_initialize(out_chls)
+
+        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
+        self.bn2 = bn_with_initialize(out_chls)
+
+        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
+        self.bn3 = bn_with_initialize_last(out_channels)
+
+        self.relu = P.ReLU()
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResidualBlockWithDown(nn.Cell):
+    expansion = 4
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1,
+                 down_sample=False):
+        super(ResidualBlockWithDown, self).__init__()
+
+        out_chls = out_channels // self.expansion
+        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
+        self.bn1 = bn_with_initialize(out_chls)
+
+        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
+        self.bn2 = bn_with_initialize(out_chls)
+
+        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
+        self.bn3 = bn_with_initialize_last(out_channels)
+
+        self.relu = P.ReLU()
+        self.downSample = down_sample
+
+        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
+        self.bn_down_sample = bn_with_initialize(out_channels)
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        identity = self.conv_down_sample(identity)
+        identity = self.bn_down_sample(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class MakeLayer0(nn.Cell):
+
+    def __init__(self, block, in_channels, out_channels, stride):
+        super(MakeLayer0, self).__init__()
+        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
+        self.b = block(out_channels, out_channels, stride=stride)
+        self.c = block(out_channels, out_channels, stride=1)
+
+    def construct(self, x):
+        x = self.a(x)
+        x = self.b(x)
+        x = self.c(x)
+
+        return x
+
+
+class MakeLayer1(nn.Cell):
+
+    def __init__(self, block, in_channels, out_channels, stride):
+        super(MakeLayer1, self).__init__()
+        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.b = block(out_channels, out_channels, stride=1)
+        self.c = block(out_channels, out_channels, stride=1)
+        self.d = block(out_channels, out_channels, stride=1)
+
+    def construct(self, x):
+        x = self.a(x)
+        x = self.b(x)
+        x = self.c(x)
+        x = self.d(x)
+
+        return x
+
+
+class MakeLayer2(nn.Cell):
+
+    def __init__(self, block, in_channels, out_channels, stride):
+        super(MakeLayer2, self).__init__()
+        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.b = block(out_channels, out_channels, stride=1)
+        self.c = block(out_channels, out_channels, stride=1)
+        self.d = block(out_channels, out_channels, stride=1)
+        self.e = block(out_channels, out_channels, stride=1)
+        self.f = block(out_channels, out_channels, stride=1)
+
+    def construct(self, x):
+        x = self.a(x)
+        x = self.b(x)
+        x = self.c(x)
+        x = self.d(x)
+        x = self.e(x)
+        x = self.f(x)
+
+        return x
+
+
+class MakeLayer3(nn.Cell):
+
+    def __init__(self, block, in_channels, out_channels, stride):
+        super(MakeLayer3, self).__init__()
+        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.b = block(out_channels, out_channels, stride=1)
+        self.c = block(out_channels, out_channels, stride=1)
+
+    def construct(self, x):
+        x = self.a(x)
+        x = self.b(x)
+        x = self.c(x)
+
+        return x
+
+
+class ResNet(nn.Cell):
+
+    def __init__(self, block, num_classes=100, batch_size=32):
+        super(ResNet, self).__init__()
+        self.batch_size = batch_size
+        self.num_classes = num_classes
+
+        self.conv1 = conv7x7(3, 64, stride=2, padding=0)
+
+        self.bn1 = bn_with_initialize(64)
+        self.relu = P.ReLU()
+        self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
+
+        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
+        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
+        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
+        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)
+
+        self.pool = P.ReduceMean(keep_dims=True)
+        self.squeeze = P.Squeeze(axis=(2, 3))
+        self.fc = fc_with_initialize(512 * block.expansion, num_classes)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)[0]
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.pool(x, (2, 3))
+        x = self.squeeze(x)
+        x = self.fc(x)
+        return x
+
+
+def resnet50(batch_size, num_classes):
+    return ResNet(ResidualBlock, num_classes, batch_size)
+
+
+def create_dataset(repeat_num=1, training=True, batch_size=32):
+    data_home = "/home/workspace/mindspore_dataset"
+    data_dir = data_home + "/cifar-10-batches-bin"
+    if not training:
+        data_dir = data_home + "/cifar-10-verify-bin"
+    data_set = ds.Cifar10Dataset(data_dir)
+
+    resize_height = 224
+    resize_width = 224
+    rescale = 1.0 / 255.0
+    shift = 0.0
+
+    # define map operations
+    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
+    random_horizontal_op = vision.RandomHorizontalFlip()
+    # interpolation default BILINEAR
+    resize_op = vision.Resize((resize_height, resize_width))
+    rescale_op = vision.Rescale(rescale, shift)
+    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
+    changeswap_op = vision.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    c_trans = []
+    if training:
+        c_trans = [random_crop_op, random_horizontal_op]
+    c_trans += [resize_op, rescale_op, normalize_op,
+                changeswap_op]
+
+    # apply map operations on images
+    data_set = data_set.map(input_columns="label", operations=type_cast_op)
+    data_set = data_set.map(input_columns="image", operations=c_trans)
+
+    # apply shuffle operations
+    data_set = data_set.shuffle(buffer_size=1000)
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
+
+    # apply repeat operations
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+class CrossEntropyLoss(nn.Cell):
+    def __init__(self):
+        super(CrossEntropyLoss, self).__init__()
+        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean()
+        self.one_hot = P.OneHot()
+        self.one = Tensor(1.0, mstype.float32)
+        self.zero = Tensor(0.0, mstype.float32)
+
+    def construct(self, logits, label):
+        label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
+        loss = self.cross_entropy(logits, label)[0]
+        loss = self.mean(loss, (-1,))
+        return loss
+
+
+class GradWrap(Cell):
+    """ GradWrap definition """
+
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.weights = ParameterTuple(network.trainable_params())
+
+    def construct(self, x, label):
+        weights = self.weights
+        return CP.grad_by_list(self.network, weights)(x, label)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_pynative_resnet50():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+
+    batch_size = 32
+    num_classes = 10
+    net = resnet50(batch_size, num_classes)
+    criterion = CrossEntropyLoss()
+    optimizer = Momentum(learning_rate=0.01, momentum=0.9,
+                         params=filter(lambda x: x.requires_grad, net.get_parameters()))
+
+    net_with_criterion = WithLossCell(net, criterion)
+    net_with_criterion.set_grad()
+    train_network = GradWrap(net_with_criterion)
+    train_network.set_train()
+
+    step = 0
+    max_step = 20
+    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size)
+    for element in data_set.create_dict_iterator():
+        step = step + 1
+        if step > max_step:
+            break
+        start_time = time.time()
+        input_data = Tensor(element["image"])
+        input_label = Tensor(element["label"])
+        loss_output = net_with_criterion(input_data, input_label)
+        grads = train_network(input_data, input_label)
+        optimizer(grads)
+        end_time = time.time()
+        cost_time = end_time - start_time
+        print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
+        if step > 1:
+            assert cost_time < 0.5
\ No newline at end of file
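
For orientation, here is the shape bookkeeping through the test network, computed by hand from the layer definitions above (hedged; it assumes the 224x224 resize that create_dataset applies and batch size 32):

```python
# Shape walk-through for the test network (NCHW), hand-computed:
#   input                  (32,    3, 224, 224)
#   conv7x7 /2 + bn + relu (32,   64, 112, 112)
#   maxpool 3x3 /2         (32,   64,  56,  56)
#   layer1 (MakeLayer0)    (32,  256,  56,  56)   # stride 1
#   layer2 (MakeLayer1)    (32,  512,  28,  28)   # first block stride 2
#   layer3 (MakeLayer2)    (32, 1024,  14,  14)
#   layer4 (MakeLayer3)    (32, 2048,   7,   7)
#   ReduceMean over (2, 3) (32, 2048,   1,   1)
#   squeeze + fc           (32, 10)
# Inside each bottleneck, out_chls = out_channels // 4: a 1x1 reduce,
# a 3x3 at the reduced width, then a 1x1 expand back to out_channels,
# with the identity path projected by conv_down_sample where needed.
```
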
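
The four MakeLayerX cells differ only in block count (3/4/6/3, the standard ResNet-50 layout). A hypothetical loop-based equivalent is sketched below; the test instead names each block a/b/c/..., which keeps every block individually addressable when stepping through pynative execution. Note that MakeLayer0 deviates from the sketch: its down-sample block uses stride=1 while its second block takes the layer stride.

```python
# Hypothetical loop-based layer builder (not part of the patch); block and
# down_block stand for ResidualBlock and ResidualBlockWithDown above.
import mindspore.nn as nn

def make_layer(block, down_block, num_blocks, in_channels, out_channels, stride):
    # first block projects the identity path and (usually) downsamples
    layers = [down_block(in_channels, out_channels, stride=stride, down_sample=True)]
    for _ in range(num_blocks - 1):
        layers.append(block(out_channels, out_channels, stride=1))
    return nn.SequentialCell(layers)

# usage sketch:
# layer2 = make_layer(ResidualBlock, ResidualBlockWithDown, 4, 256, 512, stride=2)
```
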
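
The closing assertions bound per-step latency only from step 2 onward, because the first pynative step typically pays a one-off operator compilation cost and would blow the 0.5 s budget. A minimal sketch of that guard, with run_step as a placeholder for the forward/backward/update triple:

```python
# Why `if step > 1` guards the perf assert: only steady-state steps are
# held to the budget; step 1 includes one-off kernel compilation.
import time

def check_step_times(run_step, max_step=20, budget_s=0.5):
    for step in range(1, max_step + 1):
        start = time.time()
        run_step()
        cost = time.time() - start
        print("======step: ", step, " cost time: ", cost)
        if step > 1:  # skip the warm-up step
            assert cost < budget_s
```

Both st tests can be run directly, e.g. `pytest -s tests/st/pynative/test_pynative_resnet50.py`, on an Ascend environment matching the platform markers.
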