|
|
|
@ -19,6 +19,7 @@ import gast
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
|
|
|
@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node,
|
|
|
|
|
|
|
|
|
|
if isinstance(var_shape_node, gast.Attribute):
|
|
|
|
|
args = [ast_to_source_code(var_shape_node.value).strip()]
|
|
|
|
|
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
|
|
|
|
|
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
|
|
|
|
|
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
|
|
|
|
|
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
|
|
|
|
|
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
|
|
|
|
|
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
|
|
|
|
|
if slice_node is not None and isinstance(slice_node, gast.Index):
|
|
|
|
|
args.append(ast_to_source_code(slice_node).strip())
|
|
|
|
|
if slice_node is not None and slice_is_num(slice_node):
|
|
|
|
|
args.append(ast_to_source_code(slice_node.slice).strip())
|
|
|
|
|
|
|
|
|
|
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
|
|
|
|
|
",".join(args), in_control_flow)
|
|
|
|
|
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
|
|
|
|
|
|
|
|
|
|
if slice_node is not None and not isinstance(slice_node, gast.Index):
|
|
|
|
|
if slice_node is not None and not slice_is_num(slice_node):
|
|
|
|
|
return gast.Subscript(
|
|
|
|
|
value=api_shape_node, slice=slice_node, ctx=gast.Load())
|
|
|
|
|
value=api_shape_node, slice=slice_node.slice, ctx=gast.Load())
|
|
|
|
|
return api_shape_node
|
|
|
|
|
|
|
|
|
|
if isinstance(var_shape_node, gast.Subscript):
|
|
|
|
|
result_node = copy.deepcopy(var_shape_node)
|
|
|
|
|
result_node = create_convert_shape_node(
|
|
|
|
|
result_node.value, result_node.slice, in_control_flow)
|
|
|
|
|
result_node = create_convert_shape_node(result_node.value, result_node,
|
|
|
|
|
in_control_flow)
|
|
|
|
|
return result_node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
|
|
|
|
|
# Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
|
|
|
|
|
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
|
|
|
|
|
api_shape_name)
|
|
|
|
|
args = [attr_shape_name, eval_exist_func]
|
|
|
|
|
|
|
|
|
|
if slice_node is not None and isinstance(slice_node, gast.Index):
|
|
|
|
|
args.append(ast_to_source_code(slice_node).strip())
|
|
|
|
|
if slice_node is not None and slice_is_num(slice_node):
|
|
|
|
|
args.append(ast_to_source_code(slice_node.slice).strip())
|
|
|
|
|
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
|
|
|
|
|
",".join(args))
|
|
|
|
|
choose_shape_node = gast.parse(choose_shape_func).body[0].value
|
|
|
|
|
if slice_node is not None and not isinstance(slice_node, gast.Index):
|
|
|
|
|
if slice_node is not None and not slice_is_num(slice_node):
|
|
|
|
|
return gast.Subscript(
|
|
|
|
|
value=choose_shape_node, slice=slice_node, ctx=gast.Load())
|
|
|
|
|
value=choose_shape_node, slice=slice_node.slice, ctx=gast.Load())
|
|
|
|
|
return choose_shape_node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
|
|
|
|
|
value_node):
|
|
|
|
|
return create_choose_shape_node(
|
|
|
|
|
value_node.id, self.name_to_var_shape[value_node.id],
|
|
|
|
|
slice_node)
|
|
|
|
|
value_node.id, self.name_to_var_shape[value_node.id], node)
|
|
|
|
|
elif isinstance(value_node, gast.Attribute):
|
|
|
|
|
if self._used_by_paddle_api(value_node):
|
|
|
|
|
value_name = ast_to_source_code(value_node).strip()
|
|
|
|
|
if value_name in self.name_to_var_shape:
|
|
|
|
|
return create_choose_shape_node(
|
|
|
|
|
value_name, self.name_to_var_shape[value_name],
|
|
|
|
|
slice_node)
|
|
|
|
|
value_name, self.name_to_var_shape[value_name], node)
|
|
|
|
|
if self._is_var_shape(value_node):
|
|
|
|
|
return create_convert_shape_node(value_node, slice_node)
|
|
|
|
|
return create_convert_shape_node(value_node, node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_Attribute(self, node):
|
|
|
|
@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
|
|
|
|
|
static_shape_value_name = self.name_to_var_shape[
|
|
|
|
|
value_node.id]
|
|
|
|
|
static_shape_value_node = gast.parse(
|
|
|
|
|
static_shape_value_name).body[0].value
|
|
|
|
|
index_value_node = gast.Constant(value=idx, kind=None)
|
|
|
|
|
slice_index_node = gast.Index(value=index_value_node)
|
|
|
|
|
sub_node = gast.Subscript(
|
|
|
|
|
value=static_shape_value_node,
|
|
|
|
|
slice=slice_index_node,
|
|
|
|
|
ctx=gast.Load())
|
|
|
|
|
|
|
|
|
|
sub_node_str = "{}[{}]".format(static_shape_value_name,
|
|
|
|
|
idx)
|
|
|
|
|
sub_node = gast.parse(sub_node_str).body[0].value
|
|
|
|
|
|
|
|
|
|
update_static_shape_var_node.append(
|
|
|
|
|
gast.Assign(
|
|
|
|
@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
# x.shape becomes convert_var_shape_simple(x)
|
|
|
|
|
static_shape_value_node = ShapeAttributeTransformer(
|
|
|
|
|
).visit(static_shape_value_node)
|
|
|
|
|
index_value_node = gast.Constant(value=idx, kind=None)
|
|
|
|
|
slice_index_node = gast.Index(value=index_value_node)
|
|
|
|
|
sub_node = gast.Subscript(
|
|
|
|
|
value=static_shape_value_node,
|
|
|
|
|
slice=slice_index_node,
|
|
|
|
|
ctx=gast.Load())
|
|
|
|
|
|
|
|
|
|
sub_node_str = "{}[{}]".format(
|
|
|
|
|
ast_to_source_code(static_shape_value_node).strip(),
|
|
|
|
|
idx)
|
|
|
|
|
sub_node = gast.parse(sub_node_str).body[0].value
|
|
|
|
|
|
|
|
|
|
update_static_shape_var_node.append(
|
|
|
|
|
gast.Assign(
|
|
|
|
|