|
|
|
@ -14,10 +14,10 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict, OrderedDict, Callable
|
|
|
|
|
from collections import defaultdict, MutableSet
|
|
|
|
|
from .. import core
|
|
|
|
|
from ... import compat as cpt
|
|
|
|
|
from ..framework import Program, default_main_program, Parameter, Variable
|
|
|
|
|
from ..framework import Program, default_main_program, Parameter, Variable, core
|
|
|
|
|
from ..backward import _rename_arg_
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from six.moves import range
|
|
|
|
@ -44,17 +44,82 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
|
|
|
|
|
PRINT_LOG = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OrderedSet(MutableSet):
    """Set that iterates over its elements in the order they were added.

    Elements are kept in a circular doubly linked list whose nodes are
    three-item lists ``[key, prev, next]``. ``self.end`` is the sentinel
    node of that list and ``self.map`` maps every key to its node, so
    membership tests and removals look the node up directly instead of
    walking the list.
    """

    def __init__(self, iterable=None):
        """Create an empty set, then add the items of ``iterable`` in order.

        Args:
            iterable: optional iterable of hashable items; ``None`` (the
                default) creates an empty set.
        """
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self.update(iterable)

    def __len__(self):
        """Return the number of elements in the set."""
        return len(self.map)

    def __contains__(self, key):
        """Return True if ``key`` is an element of the set."""
        return key in self.map

    def add(self, key):
        """Append ``key`` at the end of the order; no-op if already present."""
        if key in self.map:
            return
        end = self.end
        last = end[1]
        node = [key, last, end]
        # Link the new node between the current last node and the sentinel.
        last[2] = node
        end[1] = node
        self.map[key] = node

    def update(self, other):
        """Add every element of the iterable ``other``, keeping its order."""
        for item in other:
            self.add(item)

    def discard(self, key):
        """Remove ``key`` if present; do nothing otherwise."""
        if key in self.map:
            _, prev_node, next_node = self.map.pop(key)
            # Unlink the node by joining its two neighbours.
            prev_node[2] = next_node
            next_node[1] = prev_node

    def remove(self, key):
        """Remove ``key`` from the set.

        NOTE: unlike ``set.remove`` this does not raise ``KeyError`` when
        ``key`` is missing; it simply delegates to ``discard``. The lenient
        behaviour is kept because existing callers may depend on it.
        """
        self.discard(key)

    def __iter__(self):
        """Yield the elements from first added to last added."""
        end = self.end
        node = end[2]
        while node is not end:
            yield node[0]
            node = node[2]

    def __reversed__(self):
        """Yield the elements from last added to first added."""
        end = self.end
        node = end[1]
        while node is not end:
            yield node[0]
            node = node[1]

    def pop(self, last=True):
        """Remove and return one element.

        Args:
            last: when true (default) pop the most recently added element,
                otherwise pop the oldest one.

        Raises:
            KeyError: if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        node = self.end[1] if last else self.end[2]
        key = node[0]
        self.discard(key)
        return key

    def __repr__(self):
        """Return ``ClassName()`` or ``ClassName([items...])``."""
        if not self:
            return '%s()' % (self.__class__.__name__, )
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """Compare with another object.

        Two ``OrderedSet`` objects are equal only when they hold the same
        elements in the same order. Against any other iterable the
        comparison ignores order. Objects that cannot be turned into a set
        (non-iterables, iterables of unhashable items) yield
        ``NotImplemented`` so that ``==`` evaluates to False instead of
        raising ``TypeError``.
        """
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            other_items = set(other)
        except TypeError:
            return NotImplemented
        return set(self) == other_items

    def __ne__(self, other):
        """Return the negation of ``__eq__``, passing through NotImplemented."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlFlowGraph(object):
|
|
|
|
|
def __init__(self, program, ops, forward_num, skip_opt):
    """Store the constructor arguments and create empty graph tables.

    Args:
        program: stored unchanged as ``self._program``.
        ops: stored unchanged as ``self._ops``.
        forward_num: stored unchanged as ``self._forward_num``.
        skip_opt: stored unchanged as ``self._skip_opt``.
    """
    self._program = program
    self._ops = ops
    self._forward_num = forward_num
    # Every table maps a key to an OrderedSet so that iterating an entry
    # follows insertion order. Other methods index these tables with ``i``
    # (presumably the op index -- confirm against the rest of the class).
    self._successors = defaultdict(OrderedSet)
    self._presuccessors = defaultdict(OrderedSet)
    self._uses = defaultdict(OrderedSet)
    self._defs = defaultdict(OrderedSet)
    self._live_in = defaultdict(OrderedSet)
    self._live_out = defaultdict(OrderedSet)
    self._skip_opt = skip_opt
    self.pool = []
|
|
|
|
|
|
|
|
|
@ -116,7 +181,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
# NOTE: must sort the in_diff set for cases that get different cache var.
|
|
|
|
|
# FIXME(typhoonzero): using a "sorted set" may be better than this.
|
|
|
|
|
can_optimize = [
|
|
|
|
|
x for x in sorted(list(in_diff))
|
|
|
|
|
x for x in in_diff
|
|
|
|
|
if self._check_var_validity(block_desc, x, is_forward)
|
|
|
|
|
]
|
|
|
|
|
if can_optimize:
|
|
|
|
@ -224,7 +289,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
if self.pool:
|
|
|
|
|
# NOTE: must sort the in_diff set for cases that get different cache var.
|
|
|
|
|
defs_can_optimize = [
|
|
|
|
|
x for x in sorted(list(self._defs[i]))
|
|
|
|
|
x for x in self._defs[i]
|
|
|
|
|
if self._check_var_validity(block_desc, x, is_forward)
|
|
|
|
|
]
|
|
|
|
|
out_pair = [
|
|
|
|
@ -381,7 +446,19 @@ def _get_cfgs(input_program):
|
|
|
|
|
return cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
|
|
|
|
|
def _is_opt_role_op(op):
    """Tell whether ``op`` is tagged with the Optimize op role.

    Args:
        op: an operator exposing ``attr_names`` and ``all_attrs()``.

    Returns:
        bool: True when the op has the op-role attribute and its integer
        value equals ``OpRole.Optimize``; False otherwise.
    """
    op_maker = core.op_proto_and_checker_maker
    role_attr_name = op_maker.kOpRoleAttrName()
    if role_attr_name not in op.attr_names:
        return False
    # Compare as ints because the stored attribute and the enum value may
    # not be of the same type.
    op_role = int(op.all_attrs()[role_attr_name])
    return op_role == int(op_maker.OpRole.Optimize)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def memory_optimize(input_program,
|
|
|
|
|
skip_opt_set=None,
|
|
|
|
|
print_log=False,
|
|
|
|
|
level=0,
|
|
|
|
|
skip_grads=False):
|
|
|
|
|
"""Optimize memory by reusing var memory.
|
|
|
|
|
|
|
|
|
|
Note: it does not support sub-blocks nested in sub-blocks.
|
|
|
|
@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
|
|
|
|
|
raise ValueError("only support opt_level 0 or 1.")
|
|
|
|
|
global PRINT_LOG
|
|
|
|
|
PRINT_LOG = print_log
|
|
|
|
|
if skip_grads:
|
|
|
|
|
grad_set = set()
|
|
|
|
|
OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
|
|
|
|
|
for op in input_program.global_block().ops:
|
|
|
|
|
if _is_opt_role_op(op):
|
|
|
|
|
if op.attr(OP_ROLE_VAR):
|
|
|
|
|
grad_name = op.attr(OP_ROLE_VAR)[1]
|
|
|
|
|
grad_set.add(grad_name)
|
|
|
|
|
if not skip_opt_set:
|
|
|
|
|
skip_opt_set = grad_set
|
|
|
|
|
else:
|
|
|
|
|
skip_opt_set.update(grad_set)
|
|
|
|
|
|
|
|
|
|
cfgs = _get_cfgs(input_program)
|
|
|
|
|
for cfg in cfgs:
|
|
|
|
|
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
|
|
|
|
|