@@ -15,11 +15,15 @@
 import threading
 import time
 import logging
+import numpy as np
+logging.basicConfig()
 from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
 from .device_worker import Hogwild, DownpourSGD, Section
+from .framework import Variable
+from multiprocessing import Process, Manager
 
-__all__ = ["TrainerFactory"]
+__all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]
@@ -93,68 +97,74 @@ class FetchHandlerMonitor(object):
     def __init__(self, scope, handler):
         self.fetch_instance = handler
         self.fetch_thread = threading.Thread(
-            target=self.handler_decorator,
-            args=(scope, self.fetch_instance.handler))
+            target=self.handler_launch_func, args=(scope, self.fetch_instance))
+        self.running_lock = threading.Lock()
         self.running = False
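+        # self.running is read by the monitor thread and toggled by start()/stop()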
-    def start(self):
-        """
-        start monitor,
-        it will start a monitor thread.
-        """
-        self.running = True
-        self.fetch_thread.setDaemon(True)
-        self.fetch_thread.start()
-
-    def handler_decorator(self, fetch_scope, fetch_handler):
-        """
-        decorator of handler,
-        Args:
-            fetch_scope(Scope): fetch scope
-            fetch_handler(Handler): fetch handler
-        """
-        fetch_target_names = self.fetch_instance.fetch_target_names
-        period_secs = self.fetch_instance.period_secs
+    def handler_launch_func(self, scope, handler):
+        fetch_instance = handler
+        period_secs = fetch_instance.period_secs
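+        # build a reverse map: Variable name in the scope -> user-provided key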
+        var_name_to_key = {}
+        for key in fetch_instance.var_dict:
+            if isinstance(fetch_instance.var_dict[key], Variable):
+                var_name_to_key[fetch_instance.var_dict[key].name] = key
+            else:
+                logging.warning("the value of {} is not a Variable".format(key))
+                var_name_to_key["None.var"] = key
         elapsed_secs = 0
         while True:
-            while self.running and elapsed_secs >= period_secs:
+            self.running_lock.acquire()
+            if self.running == False:
+                break
+            if elapsed_secs < period_secs:
+                # TODO(guru4elephant): needs customized condition
+                time.sleep(1)
+                elapsed_secs += 1
+            else:
                 elapsed_secs = 0
-                fetch_vars = [
-                    fetch_scope.find_var(varname)
-                    for varname in fetch_target_names
-                ]
-                if None in fetch_vars:
-                    continue
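+                # period elapsed: look up each watched variable in the scope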
+                fetch_dict = {}
+                for key in var_name_to_key:
+                    var = scope.find_var(key)
+                    fetch_dict[key] = var
+                    if var == None:
+                        logging.warning("{} value currently not available".
+                                        format(var_name_to_key[key]))
+                res_dict = {}
+                for key in fetch_dict:
+                    user_name = var_name_to_key[key]
+                    if fetch_dict[key] == None:
+                        res_dict[user_name] = None
+                        continue
+                    else:
+                        res_dict[user_name] = fetch_dict[key].get_tensor()
-                fetch_tensors = [var.get_tensor() for var in fetch_vars]
-                if self.fetch_instance.return_np:
-                    fetch_nps = []
-                    for tensor in fetch_tensors:
-                        lod = tensor.lod()
-                        if len(lod) > 0:
-                            raise RuntimeError(
-                                "Some of your fetched tensors hold LoD information. \
-                                They can not be completely cast to Python ndarray. We can not \
-                                return LoDTensor itself directly, please choose another targets"
-                            )
-                        if tensor._is_initialized():
-                            fetch_nps.append(np.array(tensor))
-                        else:
-                            fetch_nps.append(None)
-                    fetch_handler(fetch_nps)
-                else:
-                    fetch_handler(fetch_tensors)
-            else:
-                time.sleep(1)
-                elapsed_secs += 1
+                    lod = res_dict[user_name].lod()
+                    if len(lod) > 0:
+                        raise RuntimeError("Some of your fetched tensors \
+                                hold LoD information. \
+                                They can not be completely cast \
+                                to Python ndarray. We can \
+                                not return LoDTensor itself directly, \
+                                please choose another targets")
+                    if res_dict[user_name]._is_initialized():
+                        res_dict[user_name] = np.array(res_dict[user_name])
+                    else:
+                        res_dict[user_name] = None
+                fetch_instance.handler(res_dict)
+            self.running_lock.release()
+
+    def start(self):
+        """
+        start monitor,
+        it will start a monitor thread.
+        """
+        self.running_lock.acquire()
+        self.running = True
+        self.running_lock.release()
+        self.fetch_thread.setDaemon(True)
+        self.fetch_thread.start()
+
+    def stop(self):
+        self.running_lock.acquire()
+        self.running = False
+        self.running_lock.release()
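
For reference, a minimal sketch of how the monitor introduced by this patch might be driven. The `FetchHandler` constructor arguments, the import path, and the `scope`/`loss_var` objects below are assumptions for illustration, not part of the patch:

    # minimal usage sketch; scope and loss_var are assumed to exist elsewhere
    from paddle.fluid.trainer_factory import FetchHandler, FetchHandlerMonitor

    class LossHandler(FetchHandler):
        def handler(self, res_dict):
            # handler_launch_func delivers numpy ndarrays (or None when a
            # variable is unavailable), keyed by the names from var_dict
            if res_dict["loss"] is not None:
                print("current loss:", res_dict["loss"])

    fetch_handler = LossHandler(var_dict={"loss": loss_var}, period_secs=60)
    monitor = FetchHandlerMonitor(scope, fetch_handler)
    monitor.start()   # spawns the daemon polling thread
    # ... training runs on other threads ...
    monitor.stop()    # flips the lock-guarded running flag; the thread exits

The lock-guarded `running` flag is what lets `stop()` end the polling thread cleanly, and user code receives plain ndarrays keyed by its own names instead of raw LoDTensors.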