!14065 New add grad accumulation for network.

From: @linqingke
Reviewed-by: @guoqi1024,@xu-yfei
Signed-off-by: @guoqi1024
pull/14065/MERGE
Committed-by: mindspore-ci-bot (via Gitee)
commit 72bad339e7
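
For orientation before the diff: on the Python side this change folds grad / grad_accumulation_step into a persistent buffer on every step, runs the optimizer only after the configured number of steps, and then clears the buffer. The following is a minimal NumPy sketch of that arithmetic only (not MindSpore code; the function and variable names are invented for illustration, and the real logic is the GradientAccumulation cell further down).

import numpy as np

def simulate_accumulation(per_step_grads, accumulation_step):
    buffer = np.zeros_like(per_step_grads[0])   # grad_accumulation parameter, init='zeros'
    applied = []                                # gradients actually handed to the optimizer
    step = 0                                    # accumulation_step counter
    for grad in per_step_grads:
        buffer += grad / accumulation_step      # _cumulative_grad: AssignAdd(buffer, grad / N)
        step += 1
        if step >= accumulation_step:           # time to run the optimizer
            applied.append(buffer.copy())       # optimizer(self._grad_accumulation)
            buffer[:] = 0.0                     # _clear_grad: assign ZerosLike
            step = 0
    return applied

grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0]), np.array([7.0, 8.0])]
print(simulate_accumulation(grads, accumulation_step=4))   # one update: the mean gradient [4.0, 5.0]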

@@ -27,23 +27,86 @@
 namespace mindspore {
 namespace opt {
 constexpr auto kSingleInputIndex = 1;
+constexpr auto kIsolatedDependRealInputIndex = 0;
+constexpr auto kIsolatedDependVirtualInputIndex = 1;
 namespace {
+CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                             const std::vector<AnfNodePtr> &new_depend_inputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_depend = nullptr;
+  if (kernel_graph == nullptr) {
+    new_depend = func_graph->NewCNode(new_depend_inputs);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_abstract(cnode->abstract());
+    new_depend->set_scope(cnode->scope());
+  } else {
+    new_depend = kernel_graph->NewCNode(cnode);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_inputs(new_depend_inputs);
+  }
+  func_graph->manager()->Replace(cnode, new_depend);
+  return new_depend;
+}
+
+CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
+    return nullptr;
+  }
+  auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
+  if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
+    return nullptr;
+  }
+  auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
+  if (!real_input_op->isa<CNode>()) {
+    return nullptr;
+  }
+  auto real_input_cnode = real_input_op->cast<CNodePtr>();
+  return real_input_cnode;
+}
+
+AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                             const CNodePtr &eliminate_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(eliminate_node);
+  auto replace_node = eliminate_node->input(kSingleInputIndex);
+  std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
+  new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
+  auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
+  auto new_node = new_cnode->cast<AnfNodePtr>();
+  return new_node;
+}
+
 AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  string op_name = AnfAlgo::GetCNodeName(cnode);
+  auto replace_cnode = cnode;
+  // Process updatestate and depend as isolated node env.
+  auto isolated_cnode = CheckIsolatedVirtualNode(replace_cnode);
+  if (isolated_cnode != nullptr) {
+    replace_cnode = isolated_cnode;
+  }
+  string op_name = AnfAlgo::GetCNodeName(replace_cnode);
   // Currently we only eliminate transdata or cast nodes.
   if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
     return nullptr;
   }
-  if (!IsNotRealUsedByOthers(func_graph, cnode)) {
+  if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
     return nullptr;
   }
-  CheckCNodeInputSize(cnode, kSingleInputIndex);
+  CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
+  if (isolated_cnode != nullptr) {
+    auto new_depend_node = EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
+    return new_depend_node;
+  }
   return cnode->input(kSingleInputIndex);
 }

@@ -137,20 +200,10 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
     return nullptr;
   }
   new_depend_inputs[replace_index] = replace_node;
-  // Because depend's input has been changed, so a new depend(UpdateState) node will be created to replaced the old one.
-  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
-  CNodePtr new_depend = nullptr;
-  if (kernel_graph == nullptr) {
-    new_depend = func_graph->NewCNode(new_depend_inputs);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_abstract(depend_cnode->abstract());
-    new_depend->set_scope(depend_cnode->scope());
-  } else {
-    new_depend = kernel_graph->NewCNode(depend_cnode);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_inputs(new_depend_inputs);
-  }
-  func_graph->manager()->Replace(depend_cnode, new_depend);
+  auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
+  if (new_depend == nullptr) {
+    return nullptr;
+  }
   return nullptr;
 }
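
The C++ change above extends the eliminate-transdata/cast pass to the monad IR: a Cast or TransData node that is only kept alive by Depend(node, UpdateState(...)), an "isolated" node, is now bypassed by rebuilding the Depend so its real input points at that node's single data input. Below is a rough self-contained Python sketch of that rewrite on toy tuple nodes, not MindSpore IR; every name here is invented for illustration, and the real pass additionally requires that the node is not really used by other nodes (omitted here).

def rewrite_isolated_depend(node):
    # toy node layout: ("Depend", real_input, virtual_input)
    if not (isinstance(node, tuple) and node and node[0] == "Depend"):
        return node
    real_input, virtual_input = node[1], node[2]
    if not (isinstance(virtual_input, tuple) and virtual_input[0] == "UpdateState"):
        return node
    if isinstance(real_input, tuple) and real_input[0] in ("Cast", "TransData"):
        # skip the eliminable node while keeping the UpdateState edge intact
        return ("Depend", real_input[1], virtual_input)
    return node

x = ("Parameter", "x")
u3 = ("UpdateState", "u2", ("Load", x, "u2"))
print(rewrite_isolated_depend(("Depend", ("Cast", x), u3)))
# -> ('Depend', ('Parameter', 'x'), ('UpdateState', ...)): the Cast is no longer referenced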

@@ -107,6 +107,20 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast<CNodePtr>();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
+// Pattern2======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
 // t = make_tuple(x, b)
 // u3 = UpdateState(u2, t)
 //==>

@@ -127,7 +141,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr
   manager->Replace(make_tuple, other_input);
 }
-// Pattern2======================================
+// Pattern3======================================
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)

@@ -153,6 +167,11 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
 void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
   auto load_users = manager->node_users()[load];
   for (const auto &load_user : load_users) {
+    // Pattern1
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
+      DeleteLoadUserUpdateState(manager, load_user.first, load);
+      continue;
+    }
     if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
       const auto &make_tuple = load_user.first->cast<CNodePtr>();
       auto &maketuple_users = manager->node_users()[make_tuple];

@@ -161,12 +180,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
       if (!maketuple_as_input_of_update) {
         continue;
       }
-      // Pattern1
+      // Pattern2
       if (make_tuple->size() == 3) {
        DeleteLoadUserMakeTuple(manager, make_tuple, load);
         continue;
       }
-      // Pattern2
+      // Pattern3
       if (make_tuple->size() > 3) {
         ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
       }

@@ -17,10 +17,11 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
 from mindspore._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
+from ...common.tensor import Tensor
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P

@@ -62,6 +63,19 @@ def _tensors_cast_datatype(datatype, param):
     """
     return F.cast(param, datatype)
+
+
+_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
+
+
+@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
+def _cumulative_grad(accumulation_step, cumulative_grad, grad):
+    """Apply gradient accumulation to cumulative grad."""
+    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
+
+
+_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
+
+
+@_gradient_clear_op.register("Tensor")
+def _clear_grad(cumulative_grad):
+    zero_grad = P.ZerosLike()(cumulative_grad)
+    return F.assign(cumulative_grad, zero_grad)
+
+
 class WithLossCell(Cell):
     r"""

@@ -347,15 +361,28 @@ class TrainOneStepCell(Cell):
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+        self.use_grad_accumulation = False
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
+            self.use_grad_accumulation = True
+        if self.use_grad_accumulation:
+            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
+            if self.max_accumulation_step <= 1:
+                self.max_accumulation_step = 1
+                self.use_grad_accumulation = False
+        if self.use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+        sens = F.depend(sens, loss)
         grads = self.grad(self.network, weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
-        loss = F.depend(loss, self.optimizer(grads))
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.optimizer(grads))
         return loss

@@ -557,3 +584,34 @@ class _BroadCastCell(Cell):
         params = self.broadcast(params)
         new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
         return new_params
+
+
+class GradientAccumulation(Cell):
+    """
+    After accumulating the gradients for multiple steps, call the optimizer to apply the update.
+
+    Args:
+        max_accumulation_step (int): Number of steps over which gradients are accumulated.
+        optimizer (Cell): Optimizer used to apply the accumulated gradients.
+    """
+    def __init__(self, max_accumulation_step, optimizer):
+        super(GradientAccumulation, self).__init__()
+        self._max_accumulation_step = max_accumulation_step
+        self.optimizer = optimizer
+        self.weights = optimizer.parameters
+        self.hyper_map = C.HyperMap()
+        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
+        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
+
+    def construct(self, loss, grads):
+        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
+                                             self._grad_accumulation, grads))
+        self._accumulation_step += 1
+
+        if self._accumulation_step >= self._max_accumulation_step:
+            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
+            self._accumulation_step = 0
+
+        if self._accumulation_step == 0:
+            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
+        return loss
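
For orientation, a hedged usage sketch of the new TrainOneStepCell path: the cell reads grad_accumulation_step from the auto-parallel context and keeps accumulation on only when the value is greater than 1 and the parallel mode is DATA_PARALLEL or STAND_ALONE. Setting the step through set_auto_parallel_context is an assumption (the diff only shows the getter), and the toy network and random data below are invented for illustration.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Assumed setter; the diff only shows get_auto_parallel_context("grad_accumulation_step").
context.set_auto_parallel_context(grad_accumulation_step=4)

net = nn.Dense(16, 10)                         # toy network, stands in for a real model
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
train_net.set_train()

data = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
for _ in range(4):
    loss = train_net(data, label)   # grad / 4 is accumulated; the optimizer runs on the 4th call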

@@ -319,7 +319,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
-            loss = F.depend(loss, self.optimizer(grads))
+            if self.use_grad_accumulation:
+                loss = self.grad_accumulation(loss, grads)
+            else:
+                loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens

     def set_sense_scale(self, sens):
