@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import core
 import framework
 import executor
@@ -20,6 +21,7 @@ import contextlib
 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
+import distribute_transpiler
 
 __all__ = [
     'Trainer',
@@ -76,22 +78,61 @@ class Trainer(object):
                 raise TypeError(
                     "The optimizer should be an instance of Optimizer")
 
-            optimizer.minimize(loss)
+            optimize_ops, params_grads = optimizer.minimize(loss)
 
         self.place = Trainer._check_and_get_place(place)
 
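+        # the optimize ops and parameter gradients returned by minimize() are
+        # forwarded to the distributed-training hook below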
+        self.dist_transpile_if_necessary(optimize_ops, params_grads)
+
         # 2. move the default_main_program to self.program and run the
         # default_startup program on an empty core.Scope()
         # Run startup program
         with self._prog_and_scope_guard():
             exe = executor.Executor(place)
-            exe.run(self.startup_program, scope=self.scope)
+            exe.run(self.startup_program)
 
         if param_path:
             # load params from param_path into scope
             # TODO(yuyang): This depends on parameters implementation.
             pass
 
-        # TODO(helin): support distributed training
+    def dist_transpile_if_necessary(self, optimize_ops, params_grads):
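+        # configure distributed training from the PADDLE_* environment variables;
+        # returns immediately when no training role is set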
if " PADDLE_TRAINING_ROLE " not in os . environ :
return
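+
+        # Illustrative setup (example values, not defaults): a pserver node might
+        # export PADDLE_TRAINING_ROLE=PSERVER, PADDLE_PSERVER_IPS="192.168.0.1",
+        # PADDLE_CURRENT_IP="192.168.0.1" and PADDLE_TRAINERS="2"; trainer k would
+        # export PADDLE_TRAINING_ROLE=TRAINER and PADDLE_TRAINER_ID=k.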
+        # the port of all pservers, needed by both trainer and pserver
+        port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+        # comma separated ips of all pservers, needed by trainer and
+        # pserver
+        pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+        eplist = []
+        for ip in pserver_ips.split(","):
+            eplist.append(':'.join([ip, port]))
+        pserver_endpoints = ",".join(eplist)
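+        # e.g. "192.168.0.1:6174,192.168.0.2:6174" (illustrative addresses)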
+        # total number of workers/trainers in the job, needed by
+        # trainer and pserver
+        trainers = int(os.getenv("PADDLE_TRAINERS"))
+        # the IP of the local machine, needed by pserver only
+        current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+        # the unique trainer id, starting from 0, needed by trainer
+        # only
+        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        # the role, should be either PSERVER or TRAINER
+        training_role = os.getenv("PADDLE_TRAINING_ROLE")
+        with self._prog_and_scope_guard():
+            t = distribute_transpiler.DistributeTranspiler()
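+            # transpile() rewrites the current default program (self.train_program,
+            # made current by _prog_and_scope_guard) into its distributed form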
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if training_role == "PSERVER":
+                self.train_program = t.get_pserver_program(current_endpoint)
+                self.startup_program = t.get_startup_program(current_endpoint,
+                                                             self.train_program)
+            elif training_role == "TRAINER":
+                self.train_program = t.get_trainer_program()
+            else:
+                raise ValueError(
+                    'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+                )
 
     def train(self,
               num_epochs,
@@ -117,6 +158,13 @@ class Trainer(object):
             raise NotImplementedError(
                 "Parallel Executor version of trainer is not implemented")
 
+        training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
+        if training_role == "PSERVER":
+            with self._prog_and_scope_guard():
+                exe = executor.Executor(self.place)
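+                # within _prog_and_scope_guard the default program is the pserver
+                # program produced by dist_transpile_if_necessary, so run() serves
+                # parameter updates for the trainers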
+                exe.run()
+                return
+
         self._train_by_executor(num_epochs, event_handler, reader, feed_order)
 
     def test(self, reader):