Merge pull request #14823 from panyx0718/fix

fix control_flow ops in outs
ce_debug
Xin Pan 6 years ago committed by GitHub
commit ed9cdb56f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1084,19 +1084,15 @@ class Block(object):
raise ValueError("var %s not in this block" % name) raise ValueError("var %s not in this block" % name)
return v return v
def _var_recursive(self, name): def _find_var_recursive(self, name):
""" """
Get a Variable by name from this block recursively. Get a Variable by name from this block recursively.
Args: Args:
name(str): the Variable's name. name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name. Or None if not found.
""" """
frontier = list() frontier = list()
visited = set() visited = set()
@ -1122,8 +1118,27 @@ class Block(object):
frontier.append(prog.block(cur.forward_block_idx)) frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur)) visited.add(id(cur))
return None
raise ValueError("Var {0} is not found recursively".format(name)) def _var_recursive(self, name):
"""
Get a Variable by name from this block recursively.
Args:
name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns:
Variable: the Variable with the giving name.
"""
var = self._find_var_recursive(name)
if var:
return var
else:
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self.iter_parameters())

@ -717,8 +717,9 @@ class While(object):
out_vars = [] out_vars = []
for inner_out_name in inner_outputs: for inner_out_name in inner_outputs:
if inner_out_name in parent_block.vars: inner_var = parent_block._find_var_recursive(inner_out_name)
out_vars.append(parent_block.var(inner_out_name)) if inner_var:
out_vars.append(inner_var)
step_scope = parent_block.create_var( step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES) type=core.VarDesc.VarType.STEP_SCOPES)
@ -1264,10 +1265,11 @@ class ConditionalBlock(object):
if each_name not in input_set if each_name not in input_set
] ]
out_list = [ out_list = []
parent_block.var(var_name) for var_name in parent_block.vars for inner_out_name in intermediate:
if var_name in intermediate inner_var = parent_block._find_var_recursive(inner_out_name)
] if inner_var:
out_list.append(inner_var)
step_scope = parent_block.create_var( step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES) type=core.VarDesc.VarType.STEP_SCOPES)

Loading…
Cancel
Save