Tensor.shape support control flow if/for/while and bugfix (#22866)
* Support Tensor.shape in control flow if/for/while and separate TensorShapeTransformer from BasicApiTransformer. test=developrevert-22710-feature/integrated_ps_api
parent
714b0076b6
commit
4af491c2bb
@ -0,0 +1,210 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
import astor
|
||||
import copy
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
|
||||
from paddle.fluid import unique_name
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
|
||||
|
||||
|
||||
class TensorShapeTransformer(gast.NodeTransformer):
|
||||
"""
|
||||
This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapper_root):
|
||||
assert isinstance(
|
||||
wrapper_root, AstNodeWrapper
|
||||
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
|
||||
self.wrapper_root = wrapper_root
|
||||
self.root = wrapper_root.node
|
||||
self.name_to_tensor_shape = {}
|
||||
|
||||
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
|
||||
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
|
||||
)
|
||||
var_env = self.static_analysis_visitor.get_var_env()
|
||||
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
|
||||
self.scope_var_type_dict = var_env.get_scope_var_type()
|
||||
|
||||
def transform(self):
|
||||
self.visit(self.root)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
if self._update_name_to_tensor_shape(node):
|
||||
return node
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
if self._used_by_paddle_api(node):
|
||||
if self.is_tensor_shape(node):
|
||||
return create_api_shape_node(node)
|
||||
return node
|
||||
|
||||
def visit_Name(self, node):
|
||||
if node.id in self.name_to_tensor_shape:
|
||||
if self._used_by_paddle_api(node):
|
||||
tensor_shape_node = self.name_to_tensor_shape[node.id]
|
||||
return create_api_shape_node(tensor_shape_node)
|
||||
return node
|
||||
|
||||
def visit_Call(self, node):
|
||||
assert isinstance(node, gast.Call)
|
||||
if is_paddle_api(node):
|
||||
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary.
|
||||
self.generic_visit(node)
|
||||
|
||||
return node
|
||||
|
||||
def visit_If(self, node):
|
||||
# Call generic_visit first to transform Tensor.shape that is used in Paddle Api.
|
||||
self.generic_visit(node)
|
||||
cond = node.test
|
||||
self._transform_tensor_shape_if_necessary(cond)
|
||||
return node
|
||||
|
||||
def visit_While(self, node):
|
||||
self.generic_visit(node)
|
||||
cond = node.test
|
||||
self._transform_tensor_shape_if_necessary(cond)
|
||||
return node
|
||||
|
||||
def visit_For(self, node):
|
||||
self.generic_visit(node)
|
||||
iter = node.iter
|
||||
self._transform_tensor_shape_if_necessary(iter)
|
||||
|
||||
# If tensor.shape is a gast.Name and it is used in range function, transform it
|
||||
self._transform_tensor_shape_in_range(node)
|
||||
return node
|
||||
|
||||
def _transform_tensor_shape_in_range(self, node):
|
||||
assert isinstance(node, gast.For)
|
||||
if not isinstance(node.iter, gast.Call):
|
||||
return False
|
||||
if not isinstance(node.iter.func, gast.Name):
|
||||
return False
|
||||
if node.iter.func.id != "range":
|
||||
return False
|
||||
args = node.iter.args
|
||||
for idx, arg in enumerate(args):
|
||||
if isinstance(arg,
|
||||
gast.Name) and arg.id in self.name_to_tensor_shape:
|
||||
args[idx] = create_api_shape_node(self.name_to_tensor_shape[
|
||||
arg.id])
|
||||
|
||||
return True
|
||||
|
||||
def _transform_tensor_shape_if_necessary(self, cond):
|
||||
for child_node in gast.walk(cond):
|
||||
tensor_shape_node = None
|
||||
if isinstance(child_node, (gast.Attribute)):
|
||||
if self.is_tensor_shape(child_node):
|
||||
tensor_shape_node = child_node
|
||||
elif isinstance(child_node, (gast.Name)):
|
||||
if child_node.id in self.name_to_tensor_shape:
|
||||
tensor_shape_node = self.name_to_tensor_shape[child_node.id]
|
||||
|
||||
if tensor_shape_node:
|
||||
wrapper_node = self.node_to_wrapper_map.get(child_node)
|
||||
parent_node = wrapper_node.parent.node
|
||||
for field, value in gast.iter_fields(parent_node):
|
||||
if child_node is value:
|
||||
setattr(parent_node, field,
|
||||
create_api_shape_node(tensor_shape_node))
|
||||
break
|
||||
|
||||
def _used_by_paddle_api(self, node):
|
||||
assert isinstance(node, (gast.Attribute, gast.Name))
|
||||
wrapper_node = self.node_to_wrapper_map.get(node)
|
||||
if not wrapper_node:
|
||||
# Transformed node is not in node_to_wrapper_map
|
||||
return False
|
||||
while wrapper_node.parent:
|
||||
parent_node = wrapper_node.parent.node
|
||||
if isinstance(parent_node, gast.Call):
|
||||
if is_paddle_api(parent_node):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
wrapper_node = wrapper_node.parent
|
||||
|
||||
return False
|
||||
|
||||
def is_tensor_shape(self, node):
|
||||
"""
|
||||
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
|
||||
"""
|
||||
assert isinstance(node, gast.Attribute)
|
||||
if node.attr != 'shape':
|
||||
return False
|
||||
|
||||
try:
|
||||
value_id = node.value.id
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if value_id in self.name_to_tensor_shape:
|
||||
return True
|
||||
|
||||
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
|
||||
# Need a better way to confirm whether `value_id` is a Tensor.
|
||||
try:
|
||||
var_type_set = self.scope_var_type_dict[value_id]
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
if NodeVarType.NUMPY_NDARRAY in var_type_set:
|
||||
return False
|
||||
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _update_name_to_tensor_shape(self, node):
|
||||
assert isinstance(node, gast.Assign)
|
||||
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
|
||||
target_node = node.targets[0]
|
||||
try:
|
||||
target_id = target_node.id
|
||||
except AttributeError:
|
||||
return False
|
||||
value_node = node.value
|
||||
|
||||
if isinstance(value_node, gast.Name):
|
||||
if value_node.id in self.name_to_tensor_shape:
|
||||
self.name_to_tensor_shape[
|
||||
target_id] = self.name_to_tensor_shape[value_node.id]
|
||||
return True
|
||||
if isinstance(value_node, gast.Attribute):
|
||||
if self.is_tensor_shape(value_node): # eg: x.shape
|
||||
self.name_to_tensor_shape[target_id] = value_node
|
||||
return True
|
||||
if isinstance(value_node, gast.Subscript):
|
||||
if isinstance(value_node.value, gast.Attribute):
|
||||
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
|
||||
self.name_to_tensor_shape[target_id] = value_node
|
||||
return True
|
||||
return False
|
Loading…
Reference in new issue