|
|
|
@ -31,6 +31,8 @@ dtype_to_size = {
|
|
|
|
|
|
|
|
|
|
# Operator type names grouped as "sub-block" ops: while, parallel_do and the
# matching *_grad entries.
# NOTE(review): presumably these are the ops that own a nested block; confirm
# against the code that consumes this list.
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
|
|
|
|
|
|
|
|
|
|
# Module-level switch for the "Hit Cache" debug message: the message is only
# printed while this is truthy. memory_optimize() overwrites it with its
# print_log argument on every call.
PRINT_LOG = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlFlowGraph(object):
|
|
|
|
|
def __init__(self, Program, ops, forward_num, skip_opt):
|
|
|
|
@ -170,7 +172,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
block_desc, cache_var, is_forward).dtype()
|
|
|
|
|
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
|
|
|
|
|
# and dtype_to_size[cache_dtype]
|
|
|
|
|
if x_dtype == cache_dtype:
|
|
|
|
|
if x_dtype == cache_dtype and PRINT_LOG:
|
|
|
|
|
print(("Hit Cache !!!! cache pool index "
|
|
|
|
|
"is %d, var name is %s, "
|
|
|
|
|
"cached var name is %s, "
|
|
|
|
@ -277,7 +279,9 @@ def _get_cfgs(input_program):
|
|
|
|
|
return cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def memory_optimize(input_program):
|
|
|
|
|
def memory_optimize(input_program, print_log=False):
    """Run memory optimization on every control flow graph of a program.

    Args:
        input_program: The program to optimize. It is handed unchanged to
            ``_get_cfgs``, which yields the control flow graphs to process.
        print_log: Value stored in the module-level ``PRINT_LOG`` flag, which
            gates the "Hit Cache" message printed during optimization.
            Defaults to False.
    """
    global PRINT_LOG
    # The flag is module-wide, so it keeps this value after the call returns.
    PRINT_LOG = print_log

    # Each graph performs its own optimization pass, in the order returned.
    for graph in _get_cfgs(input_program):
        graph.memory_optimize()
|
|
|
|
|