@@ -241,6 +241,35 @@ class Fleet(object):
        """
        self._fleet_ptr.save_model(save_path)

    def split_filelist(self, filelist):
        """
        Split the filelist before distributed training.

        For example, if filelist is [a, b, c, d, e] and trainer_num = 2,
        then trainer 0 gets [a, b, c] and trainer 1 gets [d, e].

        Args:
            filelist(list): list of filenames; the files can be local
                or on hdfs/afs.

        Returns:
            list: the filenames that belong to this trainer.
        """
        file_num = len(filelist)
        trainer_id = self.get_worker_index()
        trainer_num = self.get_worker_num()
        if trainer_num > file_num:
            raise ValueError(
                "trainer_num should be <= file_num : "
                "%s > %s" % (trainer_num, file_num)
            )
        # compute this trainer's half-open interval [start, end) of filelist
        start = 0
        end = 0
        for i in range(0, trainer_id + 1):
            # integer division; the first (file_num % trainer_num) trainers
            # each take one extra file
            length = file_num // trainer_num + (i < (file_num % trainer_num))
            start = end
            end += length
        myfilelist = filelist[start:end]
        return myfilelist
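
    # Worked example of the loop above (illustrative only, not part of the
    # patch): with file_num = 5 and trainer_num = 2,
    #   i = 0: length = 5 // 2 + (0 < 1) = 3, so [start, end) = [0, 3)
    #   i = 1: length = 5 // 2 + (1 < 1) = 2, so [start, end) = [3, 5)
    # trainer 0 stops after i = 0 and gets files [0, 3); trainer 1 runs
    # through i = 1 and gets files [3, 5), matching the docstring example.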

    def _set_opt_info(self, opt_info):
        """
        this function saves the result from DistributedOptimizer.minimize()
@@ -337,3 +366,4 @@ save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num
worker_index = fleet_instance.get_worker_index
split_filelist = fleet_instance.split_filelist
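
# A minimal usage sketch (illustrative; the file names and trainer count are
# assumptions, and fleet_instance must already be initialized so that
# get_worker_index()/get_worker_num() return valid values):
#
#     files = split_filelist(["part-00000", "part-00001", "part-00002",
#                             "part-00003", "part-00004"])
#     # with 2 trainers: trainer 0 trains on the first 3 files,
#     # trainer 1 on the remaining 2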