|
|
|
@ -16,9 +16,7 @@ import astor
|
|
|
|
|
import gast
|
|
|
|
|
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicApiTransformer(gast.NodeTransformer):
|
|
|
|
@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
|
|
|
|
|
if isinstance(child_node, gast.Call):
|
|
|
|
|
# TODO(liym27):
|
|
|
|
|
# Considers that a dygraph api which modifies the input or has a output.
|
|
|
|
|
if is_dygraph_api(child_node):
|
|
|
|
|
if utils.is_dygraph_api(child_node):
|
|
|
|
|
return
|
|
|
|
|
else:
|
|
|
|
|
self._visit_Call(child_node)
|
|
|
|
@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
|
|
|
|
|
|
|
|
|
|
if self._is_dygraph_forward(func_name):
|
|
|
|
|
class_node = self._get_class_node(func_name)
|
|
|
|
|
static_node = to_static_ast(node, class_node)
|
|
|
|
|
static_node = utils.to_static_ast(node, class_node)
|
|
|
|
|
return static_node
|
|
|
|
|
else:
|
|
|
|
|
return node
|
|
|
|
@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
|
|
|
|
|
if is_to_variable(node_value):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if is_dygraph_api(node_value):
|
|
|
|
|
if utils.is_dygraph_api(node_value):
|
|
|
|
|
dygraph_api = node_value.func.attr
|
|
|
|
|
if not dygraph_class_to_static_api.get(dygraph_api):
|
|
|
|
|
if not utils.dygraph_class_to_static_api.get(dygraph_api):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
update_args_of_func(node_value, node_value, "__init__")
|
|
|
|
|
utils.update_args_of_func(node_value, node_value, "__init__")
|
|
|
|
|
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
|
|
|
|
|
self.class_node_dict[target_str] = node_value
|
|
|
|
|
return True
|
|
|
|
|
# TODO: node.value is not dygraph class
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_to_variable(node):
|
|
|
|
|
assert isinstance(node, gast.Call)
|
|
|
|
|
api_name = utils.ast_to_source_code(node.func).strip()
|
|
|
|
|
|
|
|
|
|
if utils.is_dygraph_api(node):
|
|
|
|
|
return api_name.endswith("to_variable")
|
|
|
|
|
|
|
|
|
|
if utils.is_paddle_api(node):
|
|
|
|
|
return api_name.endswith("to_tensor")
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_assign_node(node):
|
|
|
|
|
# Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
|
|
|
|
|
# NOTE:
|
|
|
|
|
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
|
|
|
|
|
# but api `assign` only supports {float32, float64, int32, int64, bool};
|
|
|
|
|
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
|
|
|
|
|
|
|
|
|
|
assert isinstance(node, gast.Call)
|
|
|
|
|
assign_api = gast.parse('fluid.layers.assign').body[0].value
|
|
|
|
|
node.func = assign_api
|
|
|
|
|
|
|
|
|
|
if node.args:
|
|
|
|
|
node.args = [node.args[0]]
|
|
|
|
|
node.keywords = []
|
|
|
|
|
else:
|
|
|
|
|
for idx, kw in enumerate(node.keywords):
|
|
|
|
|
if kw.arg == 'value' or kw.arg == 'data':
|
|
|
|
|
node.keywords[idx].arg = 'input'
|
|
|
|
|
node.keywords = [node.keywords[idx]]
|
|
|
|
|
node.args = []
|
|
|
|
|
break
|
|
|
|
|
return node
|
|
|
|
|