@ -3,10 +3,12 @@ import collections
import numpy as np
from . import core
import proto . framework_pb2 as framework_pb2
import contextlib
__all__ = [
' Block ' , ' Variable ' , ' Program ' , ' Operator ' , ' default_startup_program ' ,
' default_main_program '
' default_main_program ' , ' program_guard ' , ' switch_startup_program ' ,
' switch_main_program '
]
@ -659,8 +661,83 @@ _startup_program_ = Program()
def default_startup_program ( ) :
"""
Get default startup program . In startup program , Paddle will initialize
parameters , initialize nccl handle , etc .
Returns :
Program : startup program
"""
return _startup_program_
def default_main_program ( ) :
"""
Get default main program . The main program is used for training or testing .
Returns :
Program : main program
"""
return _main_program_
def switch_main_program ( program ) :
"""
Switch the main program to a new program .
Args :
program ( Program ) : The new main program
Returns :
Program : The previous main program
"""
global _main_program_
prev_program = _main_program_
_main_program_ = program
return prev_program
def switch_startup_program ( program ) :
"""
Switch the startup program to a new program
Args :
program ( Program ) : The new startup program
Returns :
Program : The previous startup program
"""
global _startup_program_
prev_program = _startup_program_
_startup_program_ = program
return prev_program
@contextlib.contextmanager
def program_guard ( main_program , startup_program = None ) :
"""
Switch program with ` with ` statement
Examples :
>> > with program_guard ( Program ( ) ) :
>> > data = fluid . layers . data ( . . . )
>> > hidden = fluid . layers . fc ( . . . )
Args :
main_program ( Program ) : New main program inside ` with ` statement
startup_program ( Program ) : New startup program inside ` with ` statement .
None means do not change startup program .
Returns :
None
"""
if not isinstance ( main_program , Program ) :
raise TypeError ( " main_program should be Program " )
main_program = switch_main_program ( main_program )
if startup_program is not None :
if not isinstance ( startup_program , Program ) :
raise TypeError ( " startup_program should be Program " )
startup_program = switch_startup_program ( startup_program )
yield
switch_main_program ( main_program )
if startup_program is not None :
switch_startup_program ( startup_program )