|
|
|
@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
|
|
|
|
|
from google.protobuf import text_format
|
|
|
|
|
from . import io
|
|
|
|
|
from .data_feed_desc import DataFeedDesc
|
|
|
|
|
from .trainer_desc import TrainerDesc, MultiTrainer, DistMultiTrainer
|
|
|
|
|
from .distributed import ps_instance
|
|
|
|
|
from .contrib.utils import hdfs_utils as hdfs
|
|
|
|
|
|
|
|
|
@ -89,6 +90,38 @@ class AsyncExecutor(object):
|
|
|
|
|
self.executor = core.AsyncExecutor(scope, p)
|
|
|
|
|
self.instance = None
|
|
|
|
|
|
|
|
|
|
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
    """
    Run ``program`` over the files in ``filelist`` through a trainer.

    A ``MultiTrainer`` with the "Hogwild" worker is built when
    ``self.instance`` is None; otherwise a ``DistMultiTrainer`` with the
    "Downpour" worker and ``self.dist_desc`` is built. The trainer is then
    configured with the thread number, file list and data feed, and its
    description is handed to ``self.executor.run_from_files``.

    Args:
        program: object whose ``desc`` attribute is passed to the
            executor. ``default_main_program()`` is used when None.
        data_feed: data feed given to the trainer. Must not be None.
        filelist(str|list): file name or list of file names. A single
            str is wrapped into a one-element list.
        thread_num(int): number of threads, must be a positive int.
        fetch: accepted for interface compatibility; this implementation
            does not read it.
        debug(bool): forwarded unchanged to ``run_from_files``.

    Raises:
        ValueError: if ``data_feed`` or ``filelist`` is None, or if
            ``thread_num`` is not positive.
        TypeError: if ``thread_num`` is not an int.
    """
    if program is None:
        program = default_main_program()
    program_desc = program.desc

    if data_feed is None:
        raise ValueError('ValueError: data_feed should be provided')

    if filelist is None:
        raise ValueError('ValueError: filelist should be provided')

    if isinstance(filelist, str):
        filelist = [filelist]

    if not isinstance(thread_num, int):
        raise TypeError('TypeError: thread_num should be a positive number')
    # The message above promises a positive value; enforce it here instead
    # of handing a meaningless thread count to the trainer.
    if thread_num <= 0:
        raise ValueError('ValueError: thread_num should be a positive number')

    # Identity test: `== None` would go through a user-defined __eq__.
    is_local = self.instance is None
    if is_local:
        trainer = MultiTrainer(data_feed=data_feed, worker="Hogwild")
    else:
        # NOTE(review): self.dist_desc is assigned elsewhere in this class;
        # confirm it is always set before a distributed run reaches here.
        trainer = DistMultiTrainer(
            data_feed, worker="Downpour", fleet_desc=self.dist_desc)

    # define a trainer and a device_worker here
    trainer.set_thread(thread_num)
    trainer.set_filelist(filelist)
    trainer.set_data_feed(data_feed)
    self.executor.run_from_files(program_desc, trainer._desc(), debug)
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def run(self,
|
|
|
|
|
program,
|
|
|
|
|
data_feed,
|
|
|
|
@ -160,6 +193,7 @@ class AsyncExecutor(object):
|
|
|
|
|
self.executor.run_from_files(program_desc,
|
|
|
|
|
data_feed.desc(), filelist, thread_num,
|
|
|
|
|
fetch_var_names, mode, debug)
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def download_data(self,
|
|
|
|
|
afs_path,
|
|
|
|
@ -250,6 +284,7 @@ class AsyncExecutor(object):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'instance is None, please run config_distributed_nodes init instance'
|
|
|
|
|
)
|
|
|
|
|
self.init_desc = init_desc
|
|
|
|
|
self.executor.init_server(dist_desc, self.instance._rankid)
|
|
|
|
|
ip = self.executor.start_server()
|
|
|
|
|
self.instance.set_ip(ip)
|
|
|
|
@ -270,6 +305,8 @@ class AsyncExecutor(object):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'instance is None, please run config_distributed_nodes init instance'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.dist_desc = dist_desc
|
|
|
|
|
place = core.CPUPlace()
|
|
|
|
|
executor = Executor(place)
|
|
|
|
|
executor.run(startup_program)
|
|
|
|
|