@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +15,10 @@
from __future__ import print_function
import gast
import inspect
import logging
import numpy
import textwrap
import threading
import warnings
from paddle . fluid import framework
from paddle . fluid import core , executor
@ -32,6 +32,8 @@ from paddle.fluid.data_feeder import check_type
__all__ = [ ' ProgramTranslator ' , ' convert_function_with_cache ' ]
logger = logging . getLogger ( " fluid " )
class FunctionCache ( object ) :
"""
@ -235,6 +237,10 @@ class ProgramCache(object):
class ProgramTranslator ( object ) :
"""
Class to translate dygraph function into static graph function .
"""
_singleton_lock = threading . Lock ( )
_instance = None
@ -274,16 +280,37 @@ class ProgramTranslator(object):
self . _loss_name = None
# Once startup_program is changed, should run startup_program.
self . _prev_startup = None
self . enable_declarative = True
def enable_declarative_function ( self , enable_declarative ) :
"""
Enable or disable the converting from imperative to declarative by
ProgramTranslator globally .
Args :
enable_declarative ( bool ) : True or False to enable or disable declarative
"""
self . enable_declarative = enable_declarative
def get_output ( self , dygraph_func , * args , * * kwargs ) :
"""
Returns the output tensors for dygraph function and its arguments
Returns the output dygraph VarBase for dygraph function . The dygraph
function will be translated into static graph function so the under
beneath numerical result will be calculated by declarative mode .
Args :
dygraph_func ( callable ) : the dygraph function .
* args , * * kwargs : the input argument of dygraph_func .
Returns :
VarBase or tuple of VarBase : the dygraph VarBase containing digital
result .
"""
if in_dygraph_mode ( ) :
warnings . warn (
if in_dygraph_mode ( ) or not self . enable_declarative :
logger. info (
" The ProgramTranslator.get_output doesn ' t work in dygraph "
" mode. We will just return dygraph output. Use it in "
" static mode if you would like to translate to static graph. " )
" mode or set enable_declarative_function to False. We will "
" just return dygraph output ." )
return dygraph_func ( * args , * * kwargs )
program_cache = self . get_program_cache ( )
@ -292,33 +319,60 @@ class ProgramTranslator(object):
if not program_cache . in_build_process :
outputs = self . run ( * args , * * kwargs )
with guard ( ) :
outputs = [ to_variable ( x ) for x in outputs ]
if len ( outputs ) == 1 :
outputs = to_variable ( outputs [ 0 ] )
else :
outputs = tuple ( to_variable ( x ) for x in outputs )
return outputs
def get_func ( self , dygraph_func ) :
"""
Returns the translated static function from dygraph function
Returns a callable function which converts imperative dygraph APIs of
the input dygraph_func into declarative net - building APIs , which means
it doesn ' t return immediate digital result as get_output does.
Users should handle Program and Executor by themselves .
Args :
dygraph_func ( callable ) : the dygraph function .
Returns :
callable : converting imperative dygraph APIs into declarative
net - building APIs .
"""
if in_dygraph_mode ( ) :
warnings . warn (
if in_dygraph_mode ( ) or not self . enable_declarative :
logger. info (
" The ProgramTranslator.get_func doesn ' t work in dygraph "
" mode. We will just return dygraph function. Use it in "
" static mode if you would like to translate to static graph. " )
" mode or set enable_declarative_function to False. We will "
" just return dygraph output ." )
return dygraph_func
static_func = convert_function_with_cache ( dygraph_func )
return static_func
def get_program ( self , dygraph_func , * args , * * kwargs ) :
"""
Returns the translated static program and input / output variables from
dygraph function .
"""
if in_dygraph_mode ( ) :
warnings . warn (
dygraph function . The users can use the program to run by executor .
Args :
dygraph_func ( callable ) : the dygraph function .
* args , * * kwargs : the input argument of dygraph_func .
Returns :
tuple of ( main_program , startup_program , inputs , outputs ) whose
types are ( Program , Program , list of Variable , list of Variable ) .
main_program : the converted main program .
startup_program : the converted startup program .
inputs : list of input Variables which need to be fed .
outputs : list of output Variables which users can fetch .
"""
if in_dygraph_mode ( ) or not self . enable_declarative :
logger . info (
" The ProgramTranslator.get_program doesn ' t work in dygraph "
" mode. We will just return dygraph output. Use it in static "
" mode if you would like to translate to static graph. " )
" mode or set enable_declarative_function to False. We will "
" just return dygraph output ." )
return dygraph_func ( * args , * * kwargs )
program_cache = self . get_program_cache ( )
outputs = program_cache . build_program_and_return_output ( dygraph_func ,
* args , * * kwargs )
@ -326,7 +380,13 @@ class ProgramTranslator(object):
def get_code ( self , dygraph_func ) :
"""
Returns the translated static function code from dygraph code
Returns the translated static function string code from dygraph function .
Args :
dygraph_func ( callable ) : the dygraph function .
Returns :
str : the string code of translated static function
"""
# Gets AST from dygraph function
raw_code = inspect . getsource ( dygraph_func )