diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index e7d5417e6f..e44c6a5b3b 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -1816,8 +1816,6 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
     parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
   }
 
-  UpdataParam(func_graph, cell);
-
   // ret = cell_obj(*arg, *kwargs)
   auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
diff --git a/tests/st/networks/test_gradient_accumulation.py b/tests/st/networks/test_gradient_accumulation.py
new file mode 100644
index 0000000000..54070a9b40
--- /dev/null
+++ b/tests/st/networks/test_gradient_accumulation.py
@@ -0,0 +1,220 @@
+import os
+
+import pytest
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as CT
+import mindspore.dataset.vision.c_transforms as CV
+import mindspore.nn as nn
+from mindspore import ParameterTuple
+from mindspore import context
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import Normal
+from mindspore.dataset.vision import Inter
+from mindspore.nn import Cell
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.train.dataset_helper import DatasetHelper
+from mindspore.train.serialization import save_checkpoint
+
+_sum_op = C.MultitypeFuncGraph("grad_sum_op")
+_clear_op = C.MultitypeFuncGraph("clear_op")
+
+
+@_sum_op.register("Tensor", "Tensor")
+def _cumulative_grad(grad_sum, grad):
+    """Add the current gradient to the accumulated gradient grad_sum."""
+    add = P.AssignAdd()
+    return add(grad_sum, grad)
+
+
+@_clear_op.register("Tensor", "Tensor")
+def _clear_grad_sum(grad_sum, zero):
+    """Assign zeros to clear grad_sum."""
+    success = True
+    success = F.depend(success, F.assign(grad_sum, zero))
+    return success
+
+
+class LeNet5(nn.Cell):
+    """
+    LeNet network.
+
+    Args:
+        num_class (int): Number of classes. Default: 10.
+        num_channel (int): Number of channels. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> LeNet5(num_class=10)
+    """
+    def __init__(self, num_class=10, num_channel=1):
+        super(LeNet5, self).__init__()
+        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
+        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
+        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.max_pool2d(self.relu(self.conv1(x)))
+        x = self.max_pool2d(self.relu(self.conv2(x)))
+        x = self.flatten(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class TrainForwardBackward(Cell):
+    def __init__(self, network, optimizer, grad_sum, sens=1.0):
+        super(TrainForwardBackward, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.network.add_flags(defer_inline=True)
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = optimizer
+        self.grad_sum = grad_sum
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
+        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
+
+
+class TrainOptim(Cell):
+    def __init__(self, optimizer, grad_sum):
+        super(TrainOptim, self).__init__(auto_prefix=False)
+        self.optimizer = optimizer
+        self.grad_sum = grad_sum
+
+    def construct(self):
+        return self.optimizer(self.grad_sum)
+
+
+class TrainClear(Cell):
+    def __init__(self, grad_sum, zeros):
+        super(TrainClear, self).__init__(auto_prefix=False)
+        self.grad_sum = grad_sum
+        self.zeros = zeros
+        self.hyper_map = C.HyperMap()
+
+    def construct(self):
+        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
+        return success
+
+
+class GradientAccumulation:
+    def __init__(self, network, loss_fn, optimizer):
+        self._network = network
+        self._loss_fn = loss_fn
+        self._optimizer = optimizer
+
+        params = self._optimizer.parameters
+        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
+        self._zeros = params.clone(prefix="zeros", init='zeros')
+        self._train_forward_backward = self._build_train_forward_backward_network()
+        self._train_optim = self._build_train_optim()
+        self._train_clear = self._build_train_clear()
+
+    def _build_train_forward_backward_network(self):
+        """Build the forward and backward network."""
+        network = self._network
+        network = nn.WithLossCell(network, self._loss_fn)
+        loss_scale = 1.0
+        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
+        return network
+
+    def _build_train_optim(self):
+        """Build the optimizer network."""
+        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
+        return network
+
+    def _build_train_clear(self):
+        """Build the clear network."""
+        network = TrainClear(self._grad_sum, self._zeros).set_train()
+        return network
+
+    def train_process(self, epoch, train_dataset, mini_steps=None):
+        """
+        Training process. Data is passed to the network directly; every
+        `mini_steps` batches the accumulated gradients are applied and cleared.
+ """ + dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch) + + for i in range(epoch): + step = 0 + for k, next_element in enumerate(dataset_helper): + loss = self._train_forward_backward(*next_element) + if (k + 1) % mini_steps == 0: + step += 1 + print("epoch:", i + 1, "step:", step, "loss is ", loss) + self._train_optim() + self._train_clear() + + train_dataset.reset() + + save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt",) + + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ + create dataset for train or test + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + rescale_nml = 1 / 0.3081 + shift_nml = -1 * 0.1307 / 0.3081 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode + rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = CT.TypeCast(mstype.int32) + + # apply map operations on images + mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script + mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_gradient_accumulation(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32) + + network = LeNet5(10) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + model = GradientAccumulation(network, net_loss, net_opt) + + print("============== Starting Training ==============") + model.train_process(2, ds_train, mini_steps=4)