|
|
@ -25,6 +25,7 @@ from google.protobuf import text_format
|
|
|
|
from . import io
|
|
|
|
from . import io
|
|
|
|
from .data_feed_desc import DataFeedDesc
|
|
|
|
from .data_feed_desc import DataFeedDesc
|
|
|
|
from .distributed import ps_instance
|
|
|
|
from .distributed import ps_instance
|
|
|
|
|
|
|
|
from .contrib.utils import hdfs_utils as hdfs
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['AsyncExecutor']
|
|
|
|
__all__ = ['AsyncExecutor']
|
|
|
|
|
|
|
|
|
|
|
@ -152,6 +153,22 @@ class AsyncExecutor(object):
|
|
|
|
data_feed.desc(), filelist, thread_num,
|
|
|
|
data_feed.desc(), filelist, thread_num,
|
|
|
|
fetch_var_names, debug)
|
|
|
|
fetch_var_names, debug)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_data(self, afs_path, local_path, fs_default_name, ugi,
                  process_num=12):
    """Download this worker's share of the files under ``afs_path``.

    Builds an ``hdfs.HDFSClient`` from the given filesystem settings and
    delegates the transfer to ``hdfs.multi_download``.

    Args:
        afs_path: Remote source path handed to ``hdfs.multi_download``.
        local_path: Local destination path handed to
            ``hdfs.multi_download``.
        fs_default_name: Value used for the ``fs.default.name`` client
            config entry.
        ugi: Value used for the ``hadoop.job.ugi`` client config entry.
        process_num: Forwarded as ``multi_processes`` to
            ``hdfs.multi_download``. Defaults to 12.

    Returns:
        None.
    """
    # NOTE(review): this is the literal string "$HADOOP_HOME", not the
    # expanded environment variable -- presumably the client resolves it
    # through a shell; confirm against HDFSClient.
    hadoop_home = "$HADOOP_HOME"

    configs = {
        "fs.default.name": fs_default_name,
        "hadoop.job.ugi": ugi,
    }
    client = hdfs.HDFSClient(hadoop_home, configs)

    # Floor division keeps the trainer count an int on Python 3 as well;
    # plain `/` would pass a float there.
    # NOTE(review): halving presumably reflects that only half of the
    # nodes are workers -- confirm against the instance layout.
    worker_cnt = self.instance.get_node_cnt() // 2

    hdfs.multi_download(
        client,
        afs_path,
        local_path,
        self.instance.get_worker_index(),
        worker_cnt,
        multi_processes=process_num)
|
|
|
|
|
|
|
|
|
|
|
|
def config_distributed_nodes(self, dist_opt):
|
|
|
|
def config_distributed_nodes(self, dist_opt):
|
|
|
|
|
|
|
|
|
|
|
@ -179,10 +196,11 @@ class AsyncExecutor(object):
|
|
|
|
self.executor.gather_servers(ips, self.instance.get_node_cnt())
|
|
|
|
self.executor.gather_servers(ips, self.instance.get_node_cnt())
|
|
|
|
self.instance.barrier_all() #wait all worker start
|
|
|
|
self.instance.barrier_all() #wait all worker start
|
|
|
|
self.instance.barrier_all() #wait init model
|
|
|
|
self.instance.barrier_all() #wait init model
|
|
|
|
|
|
|
|
self.instance.barrier_all() #wait for download_data
|
|
|
|
self.instance.barrier_all() #wait worker do all things
|
|
|
|
self.instance.barrier_all() #wait worker do all things
|
|
|
|
self.instance.barrier_all() #sync
|
|
|
|
self.instance.barrier_all() #sync
|
|
|
|
|
|
|
|
|
|
|
|
def init_worker(self, dist_desc):
|
|
|
|
def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi):
|
|
|
|
self.instance.barrier_all() #wait all server start
|
|
|
|
self.instance.barrier_all() #wait all server start
|
|
|
|
ips = self.instance.gather_ips()
|
|
|
|
ips = self.instance.gather_ips()
|
|
|
|
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
|
|
|
|
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
|
|
|
@ -190,6 +208,8 @@ class AsyncExecutor(object):
|
|
|
|
if self.instance.is_first_worker():
|
|
|
|
if self.instance.is_first_worker():
|
|
|
|
self.executor.init_model()
|
|
|
|
self.executor.init_model()
|
|
|
|
self.instance.barrier_all() #wait init model
|
|
|
|
self.instance.barrier_all() #wait init model
|
|
|
|
|
|
|
|
self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
|
|
|
|
|
|
|
|
self.instance.barrier_all() #wait for download_data
|
|
|
|
|
|
|
|
|
|
|
|
def init_model(self):
|
|
|
|
def init_model(self):
|
|
|
|
self.executor.init_model()
|
|
|
|
self.executor.init_model()
|
|
|
|