@ -22,6 +22,8 @@ import paddle.fluid.layers as layers
import paddle . fluid . optimizer as optimizer
from paddle . fluid . framework import Program , program_guard
import paddle . fluid . core as core
import paddle . fluid . compiler as compiler
import os
BATCH_SIZE = 1
INPUT_SIZE = 784
@ -106,7 +108,7 @@ def static(train_data,
place = fluid . CUDAPlace ( 0 ) if use_cuda else fluid . CPUPlace ( )
exe = fluid . Executor ( place )
exe . run ( fluid. default_ startup_program( ) )
exe . run ( startup_program)
for epoch in range ( EPOCH_NUM ) :
feed_image , feed_label = train_data [ epoch ]
@ -225,5 +227,58 @@ class TestMultiTask(unittest.TestCase):
loss_2 ) )
class TestMultiOptimizersMultiCardsError ( unittest . TestCase ) :
def test_error ( self ) :
startup_program = Program ( )
main_program = Program ( )
use_cuda = core . is_compiled_with_cuda ( )
with program_guard ( main_program , startup_program ) :
def fn_1 ( opt , avg_loss ) :
opt . minimize ( avg_loss )
def fn_2 ( opt , avg_loss ) :
opt . minimize ( avg_loss )
x = fluid . layers . data ( " X " , [ 10 ] , ' float32 ' )
hidden = layers . fc ( x , 5 )
avg_loss = layers . mean ( hidden )
adam = optimizer . Adam ( learning_rate = LR )
sgd = optimizer . SGD ( learning_rate = LR )
cond = layers . fill_constant ( [ 1 ] , ' bool ' , True )
layers . case ( [ ( cond , lambda : fn_1 ( adam , avg_loss ) ) ] ,
lambda : fn_2 ( sgd , avg_loss ) )
cpu_place = fluid . CPUPlace ( )
cuda_place = fluid . CUDAPlace ( 0 ) if use_cuda else fluid . CPUPlace ( )
for place in [ cpu_place , cuda_place ] :
exe = fluid . Executor ( place )
exe . run ( startup_program )
np . random . seed ( SEED )
os . environ [ ' CPU_NUM ' ] = str ( 2 )
pe_exe = fluid . ParallelExecutor (
use_cuda = use_cuda ,
main_program = main_program ,
loss_name = avg_loss . name )
num_devices = pe_exe . device_count
def not_implemented_error ( ) :
pe_exe . run ( feed = {
' X ' : np . random . random ( size = [ 64 , 10 ] ) . astype ( ' float32 ' ) ,
} ,
fetch_list = [ avg_loss . name ] )
if num_devices > 1 :
self . assertRaises ( NotImplementedError , not_implemented_error )
else :
not_implemented_error ( )
if __name__ == ' __main__ ' :
unittest . main ( )