|
|
|
@ -157,9 +157,11 @@ class ControlFlowGraph(object):
|
|
|
|
|
if op.type() == "fill_constant" and op.attr("force_cpu") == True:
|
|
|
|
|
self._skip_opt.update(op.output_arg_names())
|
|
|
|
|
|
|
|
|
|
def release_memory(self):
|
|
|
|
|
def release_memory(self, skip_opt_set=None):
|
|
|
|
|
self._dataflow_analyze()
|
|
|
|
|
self._update_skip_opt_set()
|
|
|
|
|
if skip_opt_set:
|
|
|
|
|
self._skip_opt.update(skip_opt_set)
|
|
|
|
|
fwd_id = 0
|
|
|
|
|
bwd_id = 0
|
|
|
|
|
for i in range(self.op_size):
|
|
|
|
@ -183,7 +185,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
else:
|
|
|
|
|
bwd_id += 1
|
|
|
|
|
|
|
|
|
|
def memory_optimize(self, level=0):
|
|
|
|
|
def memory_optimize(self, skip_opt_set=None, level=0):
|
|
|
|
|
def compare_shape(x_shape, cache_shape, opt_level):
|
|
|
|
|
if opt_level == 0:
|
|
|
|
|
return x_shape == cache_shape
|
|
|
|
@ -200,6 +202,9 @@ class ControlFlowGraph(object):
|
|
|
|
|
|
|
|
|
|
self._dataflow_analyze()
|
|
|
|
|
self._update_skip_opt_set()
|
|
|
|
|
# update skip set to meet users' demand
|
|
|
|
|
if skip_opt_set:
|
|
|
|
|
self._skip_opt.update(skip_opt_set)
|
|
|
|
|
self.pool = []
|
|
|
|
|
for i in range(self.op_size):
|
|
|
|
|
op = self._ops[i]
|
|
|
|
@ -358,7 +363,7 @@ def _get_cfgs(input_program):
|
|
|
|
|
return cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def memory_optimize(input_program, print_log=False, level=0):
|
|
|
|
|
def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
|
|
|
|
|
"""Optimize memory by reusing var memory.
|
|
|
|
|
|
|
|
|
|
Note: it doesn't not support subblock nested in subblock.
|
|
|
|
@ -374,10 +379,10 @@ def memory_optimize(input_program, print_log=False, level=0):
|
|
|
|
|
PRINT_LOG = print_log
|
|
|
|
|
cfgs = _get_cfgs(input_program)
|
|
|
|
|
for cfg in cfgs:
|
|
|
|
|
cfg.memory_optimize(level)
|
|
|
|
|
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def release_memory(input_program):
|
|
|
|
|
def release_memory(input_program, skip_opt_set=None):
|
|
|
|
|
cfgs = _get_cfgs(input_program)
|
|
|
|
|
for cfg in cfgs:
|
|
|
|
|
cfg.release_memory()
|
|
|
|
|
cfg.release_memory(skip_opt_set=skip_opt_set)
|
|
|
|
|