|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from .. import core
|
|
|
|
|
from .. import compat
|
|
|
|
|
from ..framework import Program, default_main_program, Parameter
|
|
|
|
|
from ..backward import _rename_arg_
|
|
|
|
|
from functools import reduce
|
|
|
|
@ -125,15 +126,15 @@ class ControlFlowGraph(object):
|
|
|
|
|
|
|
|
|
|
def _has_var(self, block_desc, var_name, is_forward):
|
|
|
|
|
if is_forward:
|
|
|
|
|
return block_desc.has_var(str(var_name))
|
|
|
|
|
return block_desc.has_var(cpt.to_bytes(var_name))
|
|
|
|
|
else:
|
|
|
|
|
return block_desc.has_var_recursive(str(var_name))
|
|
|
|
|
return block_desc.has_var_recursive(cpt.to_bytes(var_name))
|
|
|
|
|
|
|
|
|
|
def _find_var(self, block_desc, var_name, is_forward):
|
|
|
|
|
if is_forward:
|
|
|
|
|
return block_desc.find_var(str(var_name))
|
|
|
|
|
return block_desc.find_var(cpt.to_bytes(var_name))
|
|
|
|
|
else:
|
|
|
|
|
return block_desc.find_var_recursive(str(var_name))
|
|
|
|
|
return block_desc.find_var_recursive(cpt.to_bytes(var_name))
|
|
|
|
|
|
|
|
|
|
def _check_var_validity(self, block_desc, x, is_forward):
|
|
|
|
|
if str(x) == "@EMPTY@":
|
|
|
|
@ -258,7 +259,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
# Rename the var to the cache var already with
|
|
|
|
|
# memory allocated in order to reuse the memory.
|
|
|
|
|
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
|
|
|
|
|
self._program.block(block_desc.id).var(str(
|
|
|
|
|
self._program.block(block_desc.id).var(cpt.to_literal_str(
|
|
|
|
|
x)).desc = self._find_var(block_desc, cache_var,
|
|
|
|
|
is_forward)
|
|
|
|
|
self._update_graph(x, cache_var, begin_idx=i)
|
|
|
|
|