|
|
|
@ -17,74 +17,9 @@ from __future__ import print_function
|
|
|
|
|
import astor
|
|
|
|
|
import gast
|
|
|
|
|
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
|
|
|
|
|
from paddle.fluid.framework import core, Variable
|
|
|
|
|
from paddle.fluid.layers import array_length, array_read, array_write, create_array
|
|
|
|
|
from paddle.fluid.layers import assign, fill_constant, slice
|
|
|
|
|
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(liym27): A better way to slice tensor array.
|
|
|
|
|
# Maybe support start == end for slice op.
|
|
|
|
|
def slice_tensor_array(array, start, end):
|
|
|
|
|
def true_fn():
|
|
|
|
|
null_array = create_array("float32")
|
|
|
|
|
return null_array
|
|
|
|
|
|
|
|
|
|
def false_fn(array, start, end):
|
|
|
|
|
new_array = slice(array, starts=[start], ends=[end], axes=[0])
|
|
|
|
|
return new_array
|
|
|
|
|
|
|
|
|
|
new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end))
|
|
|
|
|
return new_array
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_array_pop(array, idx):
|
|
|
|
|
assert isinstance(idx, int)
|
|
|
|
|
|
|
|
|
|
def cond(i, new_array):
|
|
|
|
|
return less_than(i, arr_len)
|
|
|
|
|
|
|
|
|
|
def body(i, new_array):
|
|
|
|
|
item = array_read(array=array, i=i)
|
|
|
|
|
array_write(item, array_length(new_array), new_array)
|
|
|
|
|
i = increment(i)
|
|
|
|
|
return i, new_array
|
|
|
|
|
|
|
|
|
|
arr_len = array_length(array)
|
|
|
|
|
if idx < 0:
|
|
|
|
|
idx = idx + arr_len
|
|
|
|
|
else:
|
|
|
|
|
idx = fill_constant(shape=[1], dtype="int64", value=idx)
|
|
|
|
|
|
|
|
|
|
pop_item = array_read(array, idx)
|
|
|
|
|
|
|
|
|
|
new_array = slice_tensor_array(array, 0, idx)
|
|
|
|
|
i = idx + 1
|
|
|
|
|
_, new_array = while_loop(cond, body, [i, new_array])
|
|
|
|
|
assign(input=new_array, output=array)
|
|
|
|
|
|
|
|
|
|
return pop_item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_list_pop(target, idx=None):
|
|
|
|
|
"""
|
|
|
|
|
Convert list pop.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if idx is None:
|
|
|
|
|
idx = -1
|
|
|
|
|
|
|
|
|
|
is_variable = isinstance(target, Variable)
|
|
|
|
|
if is_variable:
|
|
|
|
|
is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
|
|
|
|
|
if is_variable and is_tensor_array:
|
|
|
|
|
result = tensor_array_pop(target, idx)
|
|
|
|
|
else:
|
|
|
|
|
result = target.pop(idx)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ListTransformer(gast.NodeTransformer):
|
|
|
|
@ -117,7 +52,7 @@ class ListTransformer(gast.NodeTransformer):
|
|
|
|
|
if isinstance(node.func, gast.Attribute):
|
|
|
|
|
func_name = node.func.attr
|
|
|
|
|
if func_name == "pop":
|
|
|
|
|
node = self._replace_list_pop(node)
|
|
|
|
|
node = self._replace_pop(node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_Assign(self, node):
|
|
|
|
@ -283,20 +218,36 @@ class ListTransformer(gast.NodeTransformer):
|
|
|
|
|
del self.list_name_to_updated[target_id]
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _replace_list_pop(self, node):
|
|
|
|
|
def _replace_pop(self, node):
|
|
|
|
|
"""
|
|
|
|
|
Replace a pop statement for a list or dict.
|
|
|
|
|
For example:
|
|
|
|
|
|
|
|
|
|
list_a = [0,1,2,3,4]
|
|
|
|
|
x = list_a.pop() # --> convert_pop(list_a)
|
|
|
|
|
y = list_a.pop(1) # --> convert_pop(list_a, 1)
|
|
|
|
|
|
|
|
|
|
dict_a = {"red":0, "blue":1, "yellow":2}
|
|
|
|
|
m = dict_a.pop("red") # --> convert_pop(dict_a, "red")
|
|
|
|
|
n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(node, gast.Call)
|
|
|
|
|
assert isinstance(node.func, gast.Attribute)
|
|
|
|
|
|
|
|
|
|
target_node = node.func.value
|
|
|
|
|
target_str = ast_to_source_code(target_node).strip()
|
|
|
|
|
|
|
|
|
|
if node.args:
|
|
|
|
|
idx_node = node.args[0]
|
|
|
|
|
idx_str = ast_to_source_code(idx_node).strip()
|
|
|
|
|
args_str = [ast_to_source_code(arg).strip() for arg in node.args]
|
|
|
|
|
|
|
|
|
|
# NOTE(liym27):
|
|
|
|
|
# 1. pop stmt for a list if len(args_str) == 0
|
|
|
|
|
# 2. pop stmt for a list or dict if len(args_str) == 1
|
|
|
|
|
# 3. pop stmt for a dict if len(args_str) == 2
|
|
|
|
|
if len(args_str) <= 2:
|
|
|
|
|
new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
|
|
|
|
|
.format(target_str, ",".join(args_str))
|
|
|
|
|
new_pop_node = gast.parse(new_pop_str).body[0].value
|
|
|
|
|
return new_pop_node
|
|
|
|
|
else:
|
|
|
|
|
idx_str = "None"
|
|
|
|
|
|
|
|
|
|
new_call_str = "fluid.dygraph.dygraph_to_static.list_transformer.convert_list_pop({}, {})".format(
|
|
|
|
|
target_str, idx_str)
|
|
|
|
|
new_call_node = gast.parse(new_call_str).body[0].value
|
|
|
|
|
return new_call_node
|
|
|
|
|
return node
|
|
|
|
|