|
|
|
@ -122,12 +122,14 @@ class Fleet(object):
|
|
|
|
|
print("You should run DistributedOptimizer.minimize() first")
|
|
|
|
|
sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
def init_worker(self, program):
|
|
|
|
|
def init_worker(self, programs):
|
|
|
|
|
"""
|
|
|
|
|
init_worker(): will be called by user. When a user knows current process is_server(), he/she
|
|
|
|
|
should call init_worker() to initialize global information about worker and connect
|
|
|
|
|
worker with pserver.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(programs, list):
|
|
|
|
|
programs = [programs]
|
|
|
|
|
if self._opt_info:
|
|
|
|
|
if "fleet_desc" in self._opt_info:
|
|
|
|
|
self._dist_desc_str = text_format.MessageToString(
|
|
|
|
@ -145,14 +147,25 @@ class Fleet(object):
|
|
|
|
|
self.role_maker_.barrier_worker()
|
|
|
|
|
if self.role_maker_.is_first_worker():
|
|
|
|
|
tables = self._dist_desc.trainer_param.dense_table._values
|
|
|
|
|
for i in range(0, len(tables)):
|
|
|
|
|
table = tables[i];
|
|
|
|
|
var_name_list = []
|
|
|
|
|
for i in range(0, len(table.dense_variable_name)):
|
|
|
|
|
var_name_list.append(table.dense_variable_name[i])
|
|
|
|
|
#print "table id ", table.table_id
|
|
|
|
|
#print "var_name_list ", var_name_list
|
|
|
|
|
self._fleet_ptr.init_model(program.desc,
|
|
|
|
|
for prog in programs:
|
|
|
|
|
prog_id = str(id(prog))
|
|
|
|
|
prog_conf = self._opt_info['program_configs'][prog_id]
|
|
|
|
|
prog_tables = {}
|
|
|
|
|
for key in prog_conf:
|
|
|
|
|
if "dense" not in key:
|
|
|
|
|
continue
|
|
|
|
|
for table_id in prog_conf[key]:
|
|
|
|
|
prog_tables[int(table_id)] = 0
|
|
|
|
|
for i in range(0, len(tables)):
|
|
|
|
|
table = tables[i]
|
|
|
|
|
if int(table.table_id) not in prog_tables:
|
|
|
|
|
continue
|
|
|
|
|
var_name_list = []
|
|
|
|
|
for i in range(0, len(table.dense_variable_name)):
|
|
|
|
|
var_name_list.append(table.dense_variable_name[i])
|
|
|
|
|
#print "table id ", table.table_id
|
|
|
|
|
#print "var_name_list ", var_name_list
|
|
|
|
|
self._fleet_ptr.init_model(prog.desc,
|
|
|
|
|
int(table.table_id),
|
|
|
|
|
var_name_list)
|
|
|
|
|
self.role_maker_.barrier_worker()
|
|
|
|
|