|
|
|
@ -171,7 +171,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
self._live_out[i] |= self._live_in[s]
|
|
|
|
|
self._live_in[i] = self._uses[i] | (
|
|
|
|
|
self._live_out[i] - self._defs[i])
|
|
|
|
|
if live_in[i] != self._live_in[i]:
|
|
|
|
|
if live_in[i] != set(self._live_in[i]):
|
|
|
|
|
for d in self._presuccessors[i]:
|
|
|
|
|
worklist.append(d)
|
|
|
|
|
|
|
|
|
@ -321,8 +321,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
|
|
|
|
|
if not compare_shape(x_shape, cache_shape, level):
|
|
|
|
|
continue
|
|
|
|
|
# TODO(qijun): actually, we should compare
|
|
|
|
|
# dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
|
|
|
|
|
# TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
|
|
|
|
|
if x_dtype != cache_dtype:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
@ -487,7 +486,6 @@ def memory_optimize(input_program,
|
|
|
|
|
skip_opt_set = grad_set
|
|
|
|
|
else:
|
|
|
|
|
skip_opt_set.update(grad_set)
|
|
|
|
|
|
|
|
|
|
cfgs = _get_cfgs(input_program)
|
|
|
|
|
for cfg in cfgs:
|
|
|
|
|
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
|
|
|
|
|