|
|
|
@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
|
|
|
|
|
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
|
|
|
|
|
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
|
|
|
|
|
api_shape_name)
|
|
|
|
|
args = [attr_shape_name, eval_exist_func]
|
|
|
|
|
|
|
|
|
@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _update_name_to_var_shape(self, node):
|
|
|
|
|
def replace_dot(name):
|
|
|
|
|
# replace all '.' into '_'
|
|
|
|
|
return name.replace('.', '_')
|
|
|
|
|
|
|
|
|
|
assert isinstance(node, gast.Assign)
|
|
|
|
|
target_node = node.targets[0]
|
|
|
|
|
value_node = node.value
|
|
|
|
@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
if value_node.id in self.name_to_var_shape:
|
|
|
|
|
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
|
|
|
|
|
static_shape_var_name = unique_name.generate(
|
|
|
|
|
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
replace_dot(target_id) +
|
|
|
|
|
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
static_shape_var_node = gast.parse(
|
|
|
|
|
static_shape_var_name).body[0].value
|
|
|
|
|
|
|
|
|
@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
if isinstance(value_node, gast.Attribute):
|
|
|
|
|
if self._is_var_shape(value_node): # eg: x.shape
|
|
|
|
|
static_shape_var_name = unique_name.generate(
|
|
|
|
|
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
replace_dot(target_id) +
|
|
|
|
|
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
static_shape_var_node = gast.parse(
|
|
|
|
|
static_shape_var_name).body[0].value
|
|
|
|
|
|
|
|
|
@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
ast_to_source_code(static_shape_value_node).strip(),
|
|
|
|
|
idx)
|
|
|
|
|
sub_node = gast.parse(sub_node_str).body[0].value
|
|
|
|
|
# Note(Aurelius84): Becuase static_shape_var_name is used in
|
|
|
|
|
# eval_if_exist_else_none() as plain string, so it will not
|
|
|
|
|
# be pasred as argument in convert_loop/ifelse. We delcare it
|
|
|
|
|
# as global var because it has unique name.
|
|
|
|
|
update_static_shape_var_node.append(
|
|
|
|
|
gast.Global(names=[static_shape_var_name]))
|
|
|
|
|
|
|
|
|
|
update_static_shape_var_node.append(
|
|
|
|
|
gast.Assign(
|
|
|
|
@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
if isinstance(value_node, gast.Name):
|
|
|
|
|
if value_node.id in self.name_to_var_shape:
|
|
|
|
|
static_shape_var_name = unique_name.generate(
|
|
|
|
|
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
replace_dot(target_id) +
|
|
|
|
|
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
static_shape_var_node = gast.parse(
|
|
|
|
|
static_shape_var_name).body[0].value
|
|
|
|
|
static_shape_value_name = self.name_to_var_shape[
|
|
|
|
@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer):
|
|
|
|
|
self.name_to_var_shape[target_id] = static_shape_var_name
|
|
|
|
|
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
|
|
|
|
|
static_shape_var_name = unique_name.generate(
|
|
|
|
|
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
|
|
|
|
|
static_shape_var_node = gast.parse(static_shape_var_name).body[
|
|
|
|
|
0].value
|
|
|
|
|
static_shape_value_node = copy.deepcopy(value_node)
|
|
|
|
|
# x.shape becomes convert_var_shape_simple(x)
|
|
|
|
|
static_shape_value_node = ShapeAttributeTransformer().visit(
|
|
|
|
|
static_shape_value_node)
|
|
|
|
|
# Declare static_shape_var_name as global var
|
|
|
|
|
update_static_shape_var_node = [
|
|
|
|
|
gast.Global(names=[static_shape_var_name])
|
|
|
|
|
]
|
|
|
|
|
update_static_shape_var_node.append(
|
|
|
|
|
gast.Assign(
|
|
|
|
|
targets=[static_shape_var_node],
|
|
|
|
|
value=static_shape_value_node)
|
|
|
|
|
]
|
|
|
|
|
value=static_shape_value_node))
|
|
|
|
|
self.name_to_var_shape[target_id] = static_shape_var_name
|
|
|
|
|
return update_static_shape_var_node
|
|
|
|
|