@@ -32,19 +32,19 @@ class TrainerDesc(object):
         import multiprocessing as mp
         # set default thread num == cpu count
         self.proto_desc.thread_num = mp.cpu_count()
+        self.fleet_desc_ = None
+        self.device_worker_ = None
 
     def set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num
 
-    def set_filelist(self, filelist):
-        self.proto_desc.filelist.extend(filelist)
-        self.proto_desc.thread_num = min(
-            len(filelist), self.proto_desc.thread_num)
+    def set_device_worker(self, device_worker):
+        self.device_worker_ = device_worker
 
-    def set_data_feed(self, datafeed):
-        self.proto_desc.data_desc.CopyFrom(datafeed.proto_desc)
+    def set_fleet_desc(self, fleet_desc):
+        self.fleet_desc_ = fleet_desc
 
-    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker=None):
+    def gen_trainer_desc(self):
         pass
 
     def _desc(self):
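Reviewer note: this hunk replaces argument threading (`gen_trainer_desc(dataset, fleet_desc, worker)`) with stored state plus setters. A minimal wiring sketch of the new call sequence; `my_device_worker` and `my_fleet_desc` are hypothetical placeholders, not objects defined in this diff:

```python
# Sketch only: assumes TrainerDesc as it looks after this hunk.
trainer = TrainerDesc()
trainer.set_thread(8)                        # override the mp.cpu_count() default
trainer.set_device_worker(my_device_worker)  # stored in self.device_worker_
trainer.set_fleet_desc(my_fleet_desc)        # stored in self.fleet_desc_
trainer.gen_trainer_desc()                   # no-op in the base class
```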
@@ -52,17 +52,14 @@ class TrainerDesc(object):
 
 
 class MultiTrainer(TrainerDesc):
-    def __init__(self, dataset=None, worker="Hogwild"):
+    def __init__(self):
         super(MultiTrainer, self).__init__()
-        if worker == "Hogwild":
-            self.proto_desc.device_worker_name = worker + "Worker"
-            self.proto_desc.class_name = "MultiTrainer"
-        else:
-            raise ValueError('ValueError: DeviceWorker %s '
-                             'is not supported in MultiTrainer' % worker)
+        pass
 
-    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker="Hogwild"):
-        super(MultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+    def gen_trainer_desc(self):
+        super(MultiTrainer, self).gen_trainer_desc()
+        self.proto_desc.class_name = "MultiTrainer"
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
 
 
 class DistMultiTrainer(TrainerDesc):
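Reviewer note: MultiTrainer no longer validates a worker name in `__init__`; `gen_trainer_desc` now delegates proto generation to whatever device worker was injected. A hedged sketch of the resulting flow, with `hogwild_worker` as a hypothetical stand-in for a device-worker object that exposes `gen_worker_desc`:

```python
# Sketch only: local multi-thread training with an injected worker.
trainer = MultiTrainer()
trainer.set_device_worker(hogwild_worker)  # must provide gen_worker_desc()
trainer.gen_trainer_desc()                 # sets class_name = "MultiTrainer",
                                           # then lets the worker fill its fields
```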
@@ -70,14 +67,10 @@ class DistMultiTrainer(TrainerDesc):
         super(DistMultiTrainer, self).__init__()
         pass
 
-    def gen_trainer_desc(self, dataset=None, fleet_desc=None,
-                         worker="Downpour"):
-        super(DistMultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+    def gen_trainer_desc(self):
+        super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
-        self.proto_desc.data_desc.CopyFrom(dataset.proto_desc)
-        worker_builder = DeviceWorkerFactory()
-        device_worker = worker_builder.create_device_worker("Downpour")
-        device_worker.gen_worker_desc(self.proto_desc, fleet_desc)
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
 
     def set_program_config(self, fleet_desc, program_id):
         for program_config in fleet_desc.trainer_param.program_config:
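Reviewer note: the DeviceWorkerFactory call moves out of `gen_trainer_desc`, so the caller now owns worker construction. Assuming the factory remains importable at the call site (it is only removed from this method, not from the codebase), the distributed flow would look roughly like:

```python
# Sketch only: caller-side wiring for the refactored DistMultiTrainer.
worker = DeviceWorkerFactory().create_device_worker("Downpour")
trainer = DistMultiTrainer()
trainer.set_device_worker(worker)
trainer.set_fleet_desc(my_fleet_desc)  # hypothetical parsed fleet config
trainer.gen_trainer_desc()             # fills class_name and worker fields
print(trainer._desc())                 # serialized trainer descriptor
```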