[dy2static]Fix a bug of is_dygraph_api and move BasicApiTransformer to a separate file(#23923)
* Move BasicApiTransformer to a separate file. test=develop * Fix a bug: A api in module is not a real dygraph api in dygraph_to_static. test=developrevert-22778-infer_var_type
parent
c645d23519
commit
37ef7c1351
@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import astor
|
||||
import gast
|
||||
|
||||
from paddle.fluid import unique_name
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
|
||||
|
||||
|
||||
class BasicApiTransformer(gast.NodeTransformer):
|
||||
"""
|
||||
Class to transform basic API from dygraph to static graph.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapper_root):
|
||||
assert isinstance(
|
||||
wrapper_root, AstNodeWrapper
|
||||
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
|
||||
|
||||
self.wrapper_root = wrapper_root
|
||||
self.root = wrapper_root.node
|
||||
self.class_node_dict = {}
|
||||
|
||||
# Used for transformation of data feed
|
||||
self.feed_name_to_arg_id = {}
|
||||
self.name_to_tensor_shape = {}
|
||||
|
||||
def transform(self):
|
||||
self.visit(self.root)
|
||||
return self.wrapper_root
|
||||
|
||||
def visit_Assign(self, node):
|
||||
if self._update_class_node_dict(node):
|
||||
return None
|
||||
|
||||
for child_node in gast.walk(node.value):
|
||||
if isinstance(child_node, gast.Call):
|
||||
self._visit_Call(child_node)
|
||||
return node
|
||||
|
||||
def visit_Expr(self, node):
|
||||
value_node = node.value
|
||||
for child_node in gast.walk(value_node):
|
||||
if isinstance(child_node, gast.Call):
|
||||
# TODO(liym27):
|
||||
# Considers that a dygraph api which modifies the input or has a output.
|
||||
if is_dygraph_api(child_node):
|
||||
return
|
||||
else:
|
||||
self._visit_Call(child_node)
|
||||
return node
|
||||
|
||||
def _visit_Call(self, node):
|
||||
assert isinstance(node, gast.Call)
|
||||
# Replace API `to_variable` with `fluid.layers.assign`
|
||||
if is_to_variable(node):
|
||||
self._update_feed_dict(node)
|
||||
node = to_assign_node(node)
|
||||
return node
|
||||
|
||||
func_name = astor.to_source(gast.gast_to_ast(node.func))
|
||||
|
||||
if self._is_dygraph_forward(func_name):
|
||||
class_node = self._get_class_node(func_name)
|
||||
static_node = to_static_ast(node, class_node)
|
||||
return static_node
|
||||
else:
|
||||
return node
|
||||
|
||||
def _is_dygraph_forward(self, func_id):
|
||||
return func_id in self.class_node_dict
|
||||
|
||||
def _get_class_node(self, func_id):
|
||||
return self.class_node_dict[func_id]
|
||||
|
||||
def _update_class_node_dict(self, node):
|
||||
assert isinstance(node, gast.Assign)
|
||||
node_value = node.value
|
||||
if isinstance(node_value, gast.Call):
|
||||
if is_to_variable(node_value):
|
||||
return False
|
||||
|
||||
if is_dygraph_api(node_value):
|
||||
dygraph_api = node_value.func.attr
|
||||
if not dygraph_class_to_static_api.get(dygraph_api):
|
||||
return False
|
||||
|
||||
update_args_of_func(node_value, node_value, "__init__")
|
||||
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
|
||||
self.class_node_dict[target_str] = node_value
|
||||
return True
|
||||
# TODO: node.value is not dygraph class
|
||||
return False
|
||||
|
||||
def _update_feed_dict(self, node):
|
||||
assert isinstance(node, gast.Call)
|
||||
|
||||
value_node = None
|
||||
for kw in node.keywords:
|
||||
if kw.arg == 'value':
|
||||
value_node = kw.value # eg: `a` for "value=a "
|
||||
if not value_node:
|
||||
value_node = node.args[0]
|
||||
|
||||
if not isinstance(value_node, gast.Name):
|
||||
return
|
||||
else:
|
||||
var_name = value_node.id
|
||||
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
|
||||
self.feed_name_to_arg_id[
|
||||
feed_var_name] = var_name # eg: "a_0" : "a"
|
||||
|
||||
def get_feed_name_to_arg_id(self):
|
||||
return self.feed_name_to_arg_id
|
Loading…
Reference in new issue