|
|
|
@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
|
|
|
|
|
|
|
|
|
|
TRUE_FUNC_PREFIX = 'true_fn'
|
|
|
|
|
FALSE_FUNC_PREFIX = 'false_fn'
|
|
|
|
@ -145,15 +146,24 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
def __init__(self, ast_node, compare_nodes_with_tensor=None):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
ast_node,
|
|
|
|
|
compare_nodes_with_tensor=None,
|
|
|
|
|
node_to_wrapper_map=None):
|
|
|
|
|
if compare_nodes_with_tensor is None:
|
|
|
|
|
compare_nodes_with_tensor = set()
|
|
|
|
|
self.ast_root = ast_node
|
|
|
|
|
self._compare_nodes_with_tensor = compare_nodes_with_tensor
|
|
|
|
|
if node_to_wrapper_map is None:
|
|
|
|
|
node_to_wrapper_map = {}
|
|
|
|
|
self.node_to_wrapper_map = node_to_wrapper_map
|
|
|
|
|
self._new_assign_nodes = []
|
|
|
|
|
|
|
|
|
|
def transform(self):
|
|
|
|
|
return self.visit(self.ast_root)
|
|
|
|
|
node = self.ast_root
|
|
|
|
|
if not is_candidate_node(node):
|
|
|
|
|
return self._create_cast_node(node)
|
|
|
|
|
return self.visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_Call(self, node):
|
|
|
|
|
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
|
|
|
|
@ -182,8 +192,11 @@ class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
def visit_BoolOp(self, node):
|
|
|
|
|
for i, child in enumerate(node.values):
|
|
|
|
|
if not is_candidate_node(child):
|
|
|
|
|
node.values[i] = self._create_bool_node(child)
|
|
|
|
|
continue
|
|
|
|
|
node_wrapper = self.node_to_wrapper_map.get(child, None)
|
|
|
|
|
if node_wrapper and node_wrapper.node_var_type & NodeVarType.TENSOR_TYPES:
|
|
|
|
|
node.values[i] = self._create_cast_node(child)
|
|
|
|
|
else:
|
|
|
|
|
node.values[i] = self._create_bool_node(child)
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
new_node = self._create_logic_node(node)
|
|
|
|
|
return new_node
|
|
|
|
@ -195,10 +208,19 @@ class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _create_cast_node(self, node):
|
|
|
|
|
template = "fluid.layers.cast(x={}, dtype='bool')"
|
|
|
|
|
|
|
|
|
|
return self._create_node_with_api_template(node, template)
|
|
|
|
|
|
|
|
|
|
def _create_bool_node(self, node):
|
|
|
|
|
template = "fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool({}))"
|
|
|
|
|
|
|
|
|
|
return self._create_node_with_api_template(node, template)
|
|
|
|
|
|
|
|
|
|
def _create_node_with_api_template(self, node, template):
|
|
|
|
|
node_code = ast_to_source_code(node)
|
|
|
|
|
new_node_str = "fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool({}))".format(
|
|
|
|
|
node_code)
|
|
|
|
|
new_node_str = template.format(node_code)
|
|
|
|
|
# gast.parse return Module(body=[expr(value=...)])
|
|
|
|
|
new_node = gast.parse(new_node_str).body[0].value
|
|
|
|
|
bool_tensor_name = unique_name.generate(PLAIN_TENSOR_PREFIX)
|
|
|
|
@ -258,7 +280,8 @@ class IfConditionVisitor(object):
|
|
|
|
|
self.static_analysis_visitor = static_analysis_visitor
|
|
|
|
|
self.visitor = IsControlFlowVisitor(node, static_analysis_visitor,
|
|
|
|
|
node_var_type_map)
|
|
|
|
|
self.transformer = NodeTestTransformer(node)
|
|
|
|
|
self.transformer = NodeTestTransformer(
|
|
|
|
|
node, node_to_wrapper_map=self.visitor.node_to_wrapper_map)
|
|
|
|
|
self.compare_nodes_with_tensor = set()
|
|
|
|
|
self._is_control_flow_if = False
|
|
|
|
|
|
|
|
|
|