@ -14,10 +14,96 @@
from __future__ import print_function
import gast
import astor
import gast
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import AstNodeWrapper , NodeVarType , StaticAnalysisVisitor
from paddle . fluid . dygraph . dygraph_to_static . utils import is_control_flow_to_transform , ast_to_source_code
from paddle . fluid . dygraph . dygraph_to_static . utils import ast_to_source_code , is_control_flow_to_transform
from paddle . fluid . framework import core , default_main_program , Variable
from paddle . fluid . layers import array_length , array_read , array_write , create_array
from paddle . fluid . layers import assign , cast , fill_constant , slice
from paddle . fluid . layers . control_flow import cond , while_loop , less_than , increment
__all__ = [ ' convert_list_pop ' ]
def create_array_in_parent_blcok ( null_array ) :
# TODO(liym27): Create a null tensor_array with the same name in parent block to avoid a bug in control flow,
# because in `null_array = create_array("float32")`, `null_array` is not a output of a real OP.
# See class ConditionalBlock for details.
prog = default_main_program ( )
parent_idx = prog . current_block ( ) . parent_idx
while parent_idx != - 1 :
parent_block = prog . block ( parent_idx )
parent_block . create_var (
name = null_array . name ,
type = core . VarDesc . VarType . LOD_TENSOR_ARRAY ,
dtype = " float32 " )
parent_idx = parent_block . parent_idx
# TODO(liym27): A better way to slice tensor array.
# Maybe support start == end for slice op.
def slice_tensor_array ( array , start , end ) :
end = cast ( end , " int32 " )
def true_fn ( ) :
null_array = create_array ( " float32 " )
create_array_in_parent_blcok ( null_array )
return null_array
def false_fn ( array , start , end ) :
new_array = slice ( array , starts = [ start ] , ends = [ end ] , axes = [ 0 ] )
return new_array
new_array = cond ( start == end , true_fn , lambda : false_fn ( array , start , end ) )
return new_array
def tensor_array_pop ( array , idx ) :
assert isinstance ( idx , int )
def cond ( i , new_array ) :
return less_than ( i , arr_len )
def body ( i , new_array ) :
item = array_read ( array = array , i = i )
array_write ( item , array_length ( new_array ) , new_array )
i = increment ( i )
return i , new_array
arr_len = array_length ( array )
if idx < 0 :
idx = idx + arr_len
else :
idx = fill_constant ( shape = [ 1 ] , dtype = " int64 " , value = idx )
pop_item = array_read ( array , idx )
new_array = slice_tensor_array ( array , 0 , idx )
i = idx + 1
_ , new_array = while_loop ( cond , body , [ i , new_array ] )
assign ( input = new_array , output = array )
return pop_item
def convert_list_pop ( target , idx = None ) :
"""
Convert list pop .
"""
if idx is None :
idx = - 1
is_variable = isinstance ( target , Variable )
if is_variable :
is_tensor_array = target . type == core . VarDesc . VarType . LOD_TENSOR_ARRAY
if is_variable and is_tensor_array :
result = tensor_array_pop ( target , idx )
else :
result = target . pop ( idx )
return result
class ListTransformer ( gast . NodeTransformer ) :
@ -45,12 +131,21 @@ class ListTransformer(gast.NodeTransformer):
self . visit ( self . root )
self . replace_list_with_tensor_array ( self . root )
def visit_Call ( self , node ) :
if isinstance ( node . func , gast . Attribute ) :
func_name = node . func . attr
if func_name == " pop " :
node = self . _replace_list_pop ( node )
return node
def visit_Assign ( self , node ) :
if self . _update_list_name_to_updated ( node ) :
return node
if self . _need_to_array_write_node ( node ) :
return self . _transform_slice_to_tensor_write ( node )
self . generic_visit ( node )
return node
def visit_If ( self , node ) :
@ -203,3 +298,21 @@ class ListTransformer(gast.NodeTransformer):
self . list_name_to_updated [ target_id ] == False :
del self . list_name_to_updated [ target_id ]
return False
def _replace_list_pop ( self , node ) :
assert isinstance ( node , gast . Call )
assert isinstance ( node . func , gast . Attribute )
target_node = node . func . value
target_str = ast_to_source_code ( target_node ) . strip ( )
if node . args :
idx_node = node . args [ 0 ]
idx_str = ast_to_source_code ( idx_node ) . strip ( )
else :
idx_str = " None "
new_call_str = " fluid.dygraph.dygraph_to_static.convert_list_pop( {} , {} ) " . format (
target_str , idx_str )
new_call_node = gast . parse ( new_call_str ) . body [ 0 ] . value
return new_call_node