From 7b48d059c5f2abd3b07f53a2a06209513e57daec Mon Sep 17 00:00:00 2001
From: linqingke
Date: Thu, 25 Mar 2021 14:04:46 +0800
Subject: [PATCH] Add gradient accumulation for networks.

---
 .../optimizer/pass/optimize_dependence.cc       | 85 +++++++++++++++----
 .../optimizer/auto_monad_eliminate.cc           | 25 +++++-
 mindspore/nn/wrap/cell_wrapper.py               | 64 +++++++++++++-
 mindspore/nn/wrap/loss_scale.py                 |  5 +-
 4 files changed, 156 insertions(+), 23 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc
index 8fbed9c8a0..58af92d8c1 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc
@@ -27,23 +27,86 @@ namespace mindspore {
 namespace opt {
 constexpr auto kSingleInputIndex = 1;
+constexpr auto kIsolatedDependRealInputIndex = 0;
+constexpr auto kIsolatedDependVirtualInputIndex = 1;
 namespace {
+CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                             const std::vector<AnfNodePtr> &new_depend_inputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_depend = nullptr;
+  if (kernel_graph == nullptr) {
+    new_depend = func_graph->NewCNode(new_depend_inputs);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_abstract(cnode->abstract());
+    new_depend->set_scope(cnode->scope());
+  } else {
+    new_depend = kernel_graph->NewCNode(cnode);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_inputs(new_depend_inputs);
+  }
+  func_graph->manager()->Replace(cnode, new_depend);
+  return new_depend;
+}
+
+CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
+    return nullptr;
+  }
+  auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
+  if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
+    return nullptr;
+  }
+  auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
+  if (!real_input_op->isa<CNode>()) {
+    return nullptr;
+  }
+  auto real_input_cnode = real_input_op->cast<CNodePtr>();
+  return real_input_cnode;
+}
+
+AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                             const CNodePtr &eliminate_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(eliminate_node);
+  auto replace_node = eliminate_node->input(kSingleInputIndex);
+  std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
+  new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
+  auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
+  auto new_node = new_cnode->cast<AnfNodePtr>();
+  return new_node;
+}
+
 AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  string op_name = AnfAlgo::GetCNodeName(cnode);
+  auto replace_cnode = cnode;
+  // Process UpdateState and Depend as an isolated node environment.
+  auto isolated_cnode = CheckIsolatedVirtualNode(replace_cnode);
+  if (isolated_cnode != nullptr) {
+    replace_cnode = isolated_cnode;
+  }
+  string op_name = AnfAlgo::GetCNodeName(replace_cnode);
   // Currently we only eliminate transdata or cast nodes.
   if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
     return nullptr;
   }
-  if (!IsNotRealUsedByOthers(func_graph, cnode)) {
+  if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
     return nullptr;
   }
-  CheckCNodeInputSize(cnode, kSingleInputIndex);
+  CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
+  if (isolated_cnode != nullptr) {
+    auto new_depend_node = EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
+    return new_depend_node;
+  }
   return cnode->input(kSingleInputIndex);
 }
@@ -137,20 +200,10 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
     return nullptr;
   }
   new_depend_inputs[replace_index] = replace_node;
-  // Because depend's input has been changed, so a new depend(UpdateState) node will be created to replaced the old one.
-  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
-  CNodePtr new_depend = nullptr;
-  if (kernel_graph == nullptr) {
-    new_depend = func_graph->NewCNode(new_depend_inputs);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_abstract(depend_cnode->abstract());
-    new_depend->set_scope(depend_cnode->scope());
-  } else {
-    new_depend = kernel_graph->NewCNode(depend_cnode);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_inputs(new_depend_inputs);
+  auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
+  if (new_depend == nullptr) {
+    return nullptr;
   }
-  func_graph->manager()->Replace(depend_cnode, new_depend);
   return nullptr;
 }
diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
index 1102d3cca3..be31d880e9 100644
--- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
@@ -107,6 +107,20 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast<CNodePtr>();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
+// Pattern2======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
 // t = make_tuple(x, b)
 // u3 = UpdateState(u2, t)
 //==>
@@ -127,7 +141,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr
   manager->Replace(make_tuple, other_input);
 }
 
-// Pattern2======================================
+// Pattern3======================================
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)
@@ -153,6 +167,11 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
 void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
   auto load_users = manager->node_users()[load];
   for (const auto &load_user : load_users) {
+    // Pattern1
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
+      DeleteLoadUserUpdateState(manager, load_user.first, load);
+      continue;
+    }
     if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
       const auto &make_tuple = load_user.first->cast<CNodePtr>();
       auto &maketuple_users = manager->node_users()[make_tuple];
@@ -161,12 +180,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
       if (!maketuple_as_input_of_update) {
         continue;
       }
-      // Pattern1
+      // Pattern2
       if (make_tuple->size() == 3) {
         DeleteLoadUserMakeTuple(manager, make_tuple, load);
         continue;
       }
-      // Pattern2
+      // Pattern3
       if (make_tuple->size() > 3) {
         ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
       }
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index b143a736bc..ec17e27578 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -17,10 +17,11 @@ from types import FunctionType, MethodType
 
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
 from mindspore._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
+from ...common.tensor import Tensor
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
@@ -62,6 +63,19 @@ def _tensors_cast_datatype(datatype, param):
     """
     return F.cast(param, datatype)
 
+_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
+
+@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
+def _cumulative_grad(accumulation_step, cumulative_grad, grad):
+    """Apply gradient accumulation to cumulative grad."""
+    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
+
+_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
+
+@_gradient_clear_op.register("Tensor")
+def _clear_grad(cumulative_grad):
+    zero_grad = P.ZerosLike()(cumulative_grad)
+    return F.assign(cumulative_grad, zero_grad)
 
 class WithLossCell(Cell):
     r"""
@@ -347,15 +361,28 @@ class TrainOneStepCell(Cell):
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+        self.use_grad_accumulation = False
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
+            self.use_grad_accumulation = True
+        if self.use_grad_accumulation:
+            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
+            if self.max_accumulation_step <= 1:
+                self.max_accumulation_step = 1
+                self.use_grad_accumulation = False
+        if self.use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
 
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        sens = F.depend(sens, loss)
         grads = self.grad(self.network, weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
-        loss = F.depend(loss, self.optimizer(grads))
+
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.optimizer(grads))
         return loss
@@ -557,3 +584,34 @@ class _BroadCastCell(Cell):
         params = self.broadcast(params)
         new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
         return new_params
+
+class GradientAccumulation(Cell):
+    """
+    Accumulate the gradients over multiple steps, then call the optimizer to apply the accumulated update.
+
+    Args:
+        max_accumulation_step (int): Number of steps over which gradients are accumulated.
+        optimizer (Cell): Optimizer used to apply the accumulated gradients.
+    """
+    def __init__(self, max_accumulation_step, optimizer):
+        super(GradientAccumulation, self).__init__()
+        self._max_accumulation_step = max_accumulation_step
+        self.optimizer = optimizer
+        self.weights = optimizer.parameters
+        self.hyper_map = C.HyperMap()
+        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
+        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
+
+    def construct(self, loss, grads):
+        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
+                                             self._grad_accumulation, grads))
+        self._accumulation_step += 1
+
+        if self._accumulation_step >= self._max_accumulation_step:
+            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
+            self._accumulation_step = 0
+
+        if self._accumulation_step == 0:
+            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
+
+        return loss
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 71e882c6c5..4439cde298 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -319,7 +319,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
-            loss = F.depend(loss, self.optimizer(grads))
+            if self.use_grad_accumulation:
+                loss = self.grad_accumulation(loss, grads)
+            else:
+                loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens
 
     def set_sense_scale(self, sens):
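
Usage sketch (not part of the patch): the patched TrainOneStepCell reads the step count through get_auto_parallel_context("grad_accumulation_step"), so the sketch below assumes the matching grad_accumulation_step keyword can be set via context.set_auto_parallel_context; the tiny network, optimizer choice, and random data are placeholders for illustration only.

import numpy as np
from mindspore import Tensor, context, nn

# Assumed setter for the "grad_accumulation_step" key that the patched TrainOneStepCell reads.
context.set_auto_parallel_context(grad_accumulation_step=4)

class TinyNet(nn.Cell):
    """Small stand-in network used only for illustration."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Dense(16, 10)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_step = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
train_step.set_train()

for step in range(8):
    data = Tensor(np.random.randn(32, 16).astype(np.float32))
    label = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
    # With accumulation enabled, the optimizer is only applied every 4 calls.
    loss = train_step(data, label)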
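
For reference, the cycle implemented by GradientAccumulation.construct is: add grad / N into a persistent buffer each step, apply the optimizer once every N steps, then clear the buffer. A minimal framework-free sketch of that arithmetic (plain Python, names are illustrative only):

# Minimal sketch of the accumulate/apply/clear cycle in GradientAccumulation.construct.
def make_accumulator(apply_update, max_accumulation_step):
    state = {"buffers": None, "step": 0}

    def accumulate(grads):
        if state["buffers"] is None:
            state["buffers"] = [0.0 for _ in grads]
        # Each step adds grad / N, so the buffer holds the mean of the last N gradients
        # (mirrors the grad / accumulation_step term in _cumulative_grad).
        state["buffers"] = [buf + g / max_accumulation_step for buf, g in zip(state["buffers"], grads)]
        state["step"] += 1
        if state["step"] >= max_accumulation_step:
            apply_update(state["buffers"])           # corresponds to self.optimizer(self._grad_accumulation)
            state["step"] = 0
            state["buffers"] = [0.0 for _ in grads]  # corresponds to _gradient_clear_op
    return accumulate

if __name__ == "__main__":
    updates = []
    acc = make_accumulator(lambda bufs: updates.append(list(bufs)), max_accumulation_step=4)
    for step in range(8):
        acc([1.0, 2.0])   # constant "gradients" for illustration
    print(updates)        # two updates, each equal to the mean gradient [1.0, 2.0]

Dividing by max_accumulation_step at accumulation time keeps the applied update on the same scale as a single-step gradient, so the optimizer's learning rate does not need to be rescaled when accumulation is enabled.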