!14065 Add gradient accumulation for network.

From: @linqingke
Reviewed-by: @guoqi1024,@xu-yfei
Signed-off-by: @guoqi1024
pull/14065/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 72bad339e7
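
This change teaches TrainOneStepCell (and TrainOneStepWithLossScaleCell) to accumulate gradients over several steps before invoking the optimizer, driven by the auto-parallel context value grad_accumulation_step. A minimal usage sketch of how the feature might be enabled, assuming set_auto_parallel_context accepts the same grad_accumulation_step key the new code reads back, and using a toy Dense network for illustration:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Accumulate gradients over 4 steps before each optimizer update; the value is
# read back inside TrainOneStepCell via
# get_auto_parallel_context("grad_accumulation_step").
context.set_auto_parallel_context(grad_accumulation_step=4)

net = nn.Dense(16, 4)                 # toy network for illustration
loss_fn = nn.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), opt)
train_net.set_train()

data = Tensor(np.random.randn(8, 16).astype(np.float32))
label = Tensor(np.random.randn(8, 4).astype(np.float32))
# Every call adds grad / 4 to the accumulation buffers; the optimizer only
# applies the combined update on every 4th call, after which the buffers
# are cleared.
loss = train_net(data, label)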

@@ -27,23 +27,86 @@
namespace mindspore {
namespace opt {
constexpr auto kSingleInputIndex = 1;
constexpr auto kIsolatedDependRealInputIndex = 0;
constexpr auto kIsolatedDependVirtualInputIndex = 1;
namespace {
CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<AnfNodePtr> &new_depend_inputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend = nullptr;
if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_abstract(cnode->abstract());
new_depend->set_scope(cnode->scope());
} else {
new_depend = kernel_graph->NewCNode(cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
}
func_graph->manager()->Replace(cnode, new_depend);
return new_depend;
}
CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
return nullptr;
}
auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
return nullptr;
}
auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
if (!real_input_op->isa<CNode>()) {
return nullptr;
}
auto real_input_cnode = real_input_op->cast<CNodePtr>();
return real_input_cnode;
}
AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const CNodePtr &eliminate_node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(eliminate_node);
auto replace_node = eliminate_node->input(kSingleInputIndex);
std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
auto new_node = new_cnode->cast<AnfNodePtr>();
return new_node;
}
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
string op_name = AnfAlgo::GetCNodeName(cnode);
auto replace_cnode = cnode;
// Process updatestate and depend as isolated node env.
auto isolated_cnode = CheckIsolatedVirtualNode(replace_cnode);
if (isolated_cnode != nullptr) {
replace_cnode = isolated_cnode;
}
string op_name = AnfAlgo::GetCNodeName(replace_cnode);
// Currently we only eliminate transdata or cast nodes.
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
return nullptr;
}
if (!IsNotRealUsedByOthers(func_graph, cnode)) {
if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
return nullptr;
}
CheckCNodeInputSize(cnode, kSingleInputIndex);
CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
if (isolated_cnode != nullptr) {
auto new_depend_node = EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
return new_depend_node;
}
return cnode->input(kSingleInputIndex);
}
@@ -137,20 +200,10 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
return nullptr;
}
new_depend_inputs[replace_index] = replace_node;
// Because the depend's input has been changed, a new depend (UpdateState) node will be created to replace the old one.
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend = nullptr;
if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_abstract(depend_cnode->abstract());
new_depend->set_scope(depend_cnode->scope());
} else {
new_depend = kernel_graph->NewCNode(depend_cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
if (new_depend == nullptr) {
return nullptr;
}
func_graph->manager()->Replace(depend_cnode, new_depend);
return nullptr;
}

@@ -107,6 +107,20 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
// u3 = UpdateState(u2, b)
//==>
// delete the UpdateState
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
const AnfNodePtr &load) {
const auto &load_cnode = load->cast<CNodePtr>();
const auto &u = load_cnode->input(2);
manager->Replace(load_user, u);
}
// Pattern2======================================
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
// t = make_tuple(x, b)
// u3 = UpdateState(u2, t)
//==>
@@ -127,7 +141,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr
manager->Replace(make_tuple, other_input);
}
// Pattern2======================================
// Pattern3======================================
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
@@ -153,6 +167,11 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
auto load_users = manager->node_users()[load];
for (const auto &load_user : load_users) {
// Pattern1
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
DeleteLoadUserUpdateState(manager, load_user.first, load);
continue;
}
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
const auto &make_tuple = load_user.first->cast<CNodePtr>();
auto &maketuple_users = manager->node_users()[make_tuple];
@@ -161,12 +180,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
if (!maketuple_as_input_of_update) {
continue;
}
// Pattern1
// Pattern2
if (make_tuple->size() == 3) {
DeleteLoadUserMakeTuple(manager, make_tuple, load);
continue;
}
// Pattern2
// Pattern3
if (make_tuple->size() > 3) {
ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
}

@@ -17,10 +17,11 @@ from types import FunctionType, MethodType
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore._checkparam import Validator as validator
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...common.tensor import Tensor
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
@@ -62,6 +63,19 @@ def _tensors_cast_datatype(datatype, param):
"""
return F.cast(param, datatype)
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
"""Apply gradient accumulation to cumulative grad."""
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
zero_grad = P.ZerosLike()(cumulative_grad)
return F.assign(cumulative_grad, zero_grad)
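
Because _cumulative_grad adds grad / accumulation_step rather than the raw grad, after a full cycle the buffer holds the mean of the accumulated gradients, so the effective step size does not grow with the accumulation length. A plain NumPy sketch of that arithmetic (illustrative only, not MindSpore code):

import numpy as np

accumulation_step = 4
grads = [np.full(3, g, dtype=np.float32) for g in (1.0, 2.0, 3.0, 4.0)]

cumulative = np.zeros(3, dtype=np.float32)
for g in grads:
    # mirrors P.AssignAdd()(cumulative_grad, grad / accumulation_step)
    cumulative += g / accumulation_step

# After accumulation_step steps the buffer equals the mean gradient.
assert np.allclose(cumulative, np.mean(grads, axis=0))
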
class WithLossCell(Cell):
r"""
@@ -347,15 +361,28 @@ class TrainOneStepCell(Cell):
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
self.use_grad_accumulation = False
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
self.use_grad_accumulation = True
if self.use_grad_accumulation:
self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
if self.max_accumulation_step <= 1:
self.max_accumulation_step = 1
self.use_grad_accumulation = False
if self.use_grad_accumulation:
self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.depend(sens, loss)
grads = self.grad(self.network, weights)(*inputs, sens)
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
if self.use_grad_accumulation:
loss = self.grad_accumulation(loss, grads)
else:
loss = F.depend(loss, self.optimizer(grads))
return loss
@@ -557,3 +584,34 @@ class _BroadCastCell(Cell):
params = self.broadcast(params)
new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
return new_params
class GradientAccumulation(Cell):
"""
After accumulating the gradients of multiple steps, call the optimizer to apply the update.
Args:
max_accumulation_step (int): Steps to accumulate gradients.
optimizer (Cell): Optimizer used.
"""
def __init__(self, max_accumulation_step, optimizer):
super(GradientAccumulation, self).__init__()
self._max_accumulation_step = max_accumulation_step
self.optimizer = optimizer
self.weights = optimizer.parameters
self.hyper_map = C.HyperMap()
self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
def construct(self, loss, grads):
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
self._grad_accumulation, grads))
self._accumulation_step += 1
if self._accumulation_step >= self._max_accumulation_step:
loss = F.depend(loss, self.optimizer(self._grad_accumulation))
self._accumulation_step = 0
if self._accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
return loss
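
The construct above increments the counter on every call, runs the optimizer only when the counter reaches _max_accumulation_step, and then clears both the counter and the buffers. A small plain-Python mock of that schedule (counter logic only, not MindSpore code; the class name is made up for illustration):

class MockAccumulation:
    """Mimics GradientAccumulation's accumulate / apply / clear schedule."""

    def __init__(self, max_accumulation_step):
        self.max_step = max_accumulation_step
        self.step = 0
        self.buffer = 0.0

    def construct(self, grad):
        self.buffer += grad / self.max_step   # _gradient_accumulation_op
        self.step += 1
        applied = False
        if self.step >= self.max_step:
            applied = True                    # optimizer(self._grad_accumulation)
            self.step = 0
        if self.step == 0:
            self.buffer = 0.0                 # _gradient_clear_op
        return applied

acc = MockAccumulation(4)
print([acc.construct(1.0) for _ in range(8)])
# -> [False, False, False, True, False, False, False, True]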

@@ -319,7 +319,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
loss = F.depend(loss, self.optimizer(grads))
if self.use_grad_accumulation:
loss = self.grad_accumulation(loss, grads)
else:
loss = F.depend(loss, self.optimizer(grads))
return loss, cond, scaling_sens
def set_sense_scale(self, sens):
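
Under loss scaling the same accumulation path is reused, but only on steps without overflow: an overflow step discards the gradients and leaves the accumulation counter untouched. A self-contained usage sketch, assuming the scale_sense keyword of TrainOneStepWithLossScaleCell and the DynamicLossScaleUpdateCell API of this MindSpore version:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_auto_parallel_context(grad_accumulation_step=4)
net = nn.Dense(16, 4)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_net = nn.WithLossCell(net, nn.MSELoss())

# Dynamic loss scale; on overflow the update (and the accumulation step)
# is skipped for that iteration.
scale_sense = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
                                            scale_factor=2, scale_window=1000)
train_net = nn.TrainOneStepWithLossScaleCell(loss_net, opt, scale_sense=scale_sense)
train_net.set_train()

data = Tensor(np.random.randn(8, 16).astype(np.float32))
label = Tensor(np.random.randn(8, 4).astype(np.float32))
loss, overflow, scaling_sens = train_net(data, label)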
