@@ -167,6 +167,113 @@ class DistributedAdam(DistributedOptimizerImplBase):
            ret_list.append(x[0])
        return ret_list

    def _if_last_block(self, op, _equal_dict):
        # for conditional_block op
        cond_str = op.input('Cond')[0]
        bool_test = False
        if cond_str.startswith('equal'):
            bool_test = True
        vars_ = op.input('Input')
        equal_keys = _equal_dict.keys()
        for var_cond in vars_:
            if var_cond in equal_keys:
                if bool_test:
                    print("the conditional block is error")
                    return False
        return True
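
    # Note: a conditional_block counts as a leaf ("last") block when none of its
    # Input variables are outputs of equal ops; if its Cond comes from an equal
    # op but an equal output still shows up among the inputs, the layout is
    # reported as broken and no cond-to-parameter entry is built for it.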

    def _generte_cond_para_map(self, op, _fill_value_dict, _equal_fill_dict,
                               _now_program, _all_params):
        # generate cond value to parameter map recursively
        cond_str = op.input('Cond')[0]
        vars_ = op.input('Input')

        if self._if_last_block(op, _equal_fill_dict):
            vars_ = op.input('Input')
            cond_key = ""
            if cond_str.startswith('equal'):
                cond_key = int(_fill_value_dict[_equal_fill_dict[cond_str]])
            else:
                cond_key = -1

            p_list = []
            for var_cond in vars_:
                if var_cond in _all_params:
                    p_list.append(var_cond)

            self._cond_params[cond_key] = p_list
            self._other_params.extend(p_list)
        else:
            ops_cond = _now_program.block(int(op.attr('sub_block').id)).ops
            for op in ops_cond:
                if op.type == 'conditional_block':
                    self._generte_cond_para_map(op, _fill_value_dict,
                                                _equal_fill_dict, _now_program,
                                                _all_params)
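
    # After the recursion, self._cond_params maps the integer behind each equal
    # condition (the fill_constant value, or -1 when the condition is not an
    # equal output) to the parameter names used only inside that branch, and
    # self._other_params collects all of those names. Illustratively, for two
    # task branches guarded by the constants 0 and 1 it could look like
    #   self._cond_params == {0: ['task0_fc.w_0', 'task0_fc.b_0'],
    #                         1: ['task1_fc.w_0', 'task1_fc.b_0']}
    # (the parameter names above are made up for illustration).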

    def _has_conditional_block(self, loss):
        now_program = loss.block.program
        root_block = now_program.block(0)
        ops_ = root_block.ops
        for op in ops_:
            if op.type == 'conditional_block':
                return True
        return False

    def _check_params_grads(self, params, grads):
        if len(params) != len(grads):
            raise ValueError("params size != grads size, %s vs %s" %
                             (len(params), len(grads)))

        pname2grad = dict()
        for i in range(len(params)):
            pname = params[i].name
            gname = grads[i].name
            if pname != gname[:-5]:
                raise ValueError(" params != grads , %s vs %s" % (pname, gname))
            pname2grad[pname] = grads[i]

        return pname2grad
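
    # The pairing above relies on fluid's gradient naming convention: each grad
    # variable is named "<param_name>@GRAD", so stripping the last five
    # characters of the grad name should give back the parameter name.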

    def _generate_multi_dense_table(self,
                                    params,
                                    grads,
                                    cond_params,
                                    other_params,
                                    sparse_table_names,
                                    dense_table_id=0):
        # generate multi dense table by cond value
        pname2grad = self._check_params_grads(params, grads)
        root_params_list = []
        root_grads_list = []
        dense_tables = []
        for i, p in enumerate(params):
            if p.name not in other_params and p.name not in sparse_table_names:
                root_params_list.append(p)
                root_grads_list.append(grads[i])
        if len(root_params_list) > 0:
            dense_tables.append(dense_table_id)
            dense_table_id += 1
        lists_params = [[] for i in range(len(cond_params.keys()))]
        lists_grads = [[] for i in range(len(cond_params.keys()))]

        key_id = 0
        name2key = dict()
        cond2denseid = dict()
        for key, value in cond_params.items():
            cond2denseid[key] = dense_table_id
            dense_tables.append(dense_table_id)
            dense_table_id += 1
            for v in value:
                name2key[v] = key_id
            key_id += 1

        for p in params:
            if p.name in other_params:
                lists_params[name2key[p.name]].append(p)
                lists_grads[name2key[p.name]].append(pname2grad[p.name])

        return dense_tables, cond2denseid, lists_params, lists_grads, root_params_list, root_grads_list
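
    # Sketch of the grouping this produces (values are illustrative only):
    # starting from dense_table_id=1 with some shared parameters plus two task
    # branches whose cond values are 0 and 1, it would return roughly
    #   dense_tables  -> [1, 2, 3]       # 1 = shared/root table, 2 and 3 per task
    #   cond2denseid  -> {0: 2, 1: 3}
    #   lists_params  -> [[task-0 params], [task-1 params]]
    #   lists_grads   -> [[task-0 grads], [task-1 grads]]
    # together with root_params_list / root_grads_list holding the shared part.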

    def _minimize(self,
                  losses,
                  startup_program=None,
@@ -215,6 +322,31 @@ class DistributedAdam(DistributedOptimizerImplBase):
                                               no_grad_set),
                key=lambda x: x[0].name)

            # having a conditional_block op means this is a multi-task program
            flag_multi_task = self._has_conditional_block(loss)
            if flag_multi_task:
                self._cond_params = dict()
                self._other_params = []
                now_program = loss.block.program
                root_block = now_program.block(0)
                all_params = []
                for par in root_block.all_parameters():
                    all_params.append(par.name)

                ops_ = root_block.ops
                fill_value_dict = dict()
                equal_fill_dict = dict()
                for op in ops_:
                    # a conditional_block op must have matching fill_constant and equal ops
                    if op.type == 'fill_constant':
                        fill_value_dict[op.output('Out')[0]] = op.attr('value')
                    if op.type == 'equal':
                        equal_fill_dict[op.output('Out')[0]] = op.input('Y')[0]
                    if op.type == 'conditional_block':
                        self._generte_cond_para_map(op, fill_value_dict,
                                                    equal_fill_dict,
                                                    now_program, all_params)
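
            # The scan above assumes the usual lowering of a task switch: each
            # branch id is materialized by a fill_constant, an equal op compares
            # the condition variable against that constant, and the branch body
            # lives in a conditional_block guarded by the equal output, so
            # walking the root block recovers which cond value owns which
            # parameters.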

            if prog_id not in program_id_set:
                program_id_set.add(prog_id)
                sparse_table = self._find_multi_distributed_lookup_table([loss])
@@ -402,17 +534,65 @@ class DistributedAdam(DistributedOptimizerImplBase):
                        data_norm_grads.append(i[1])
                if not is_data_norm_data:
                    grads.append(i[1])

            # for new dense table
            multi_task_dense_tables_push = []
            multi_task_dense_tables_pull = []
            if flag_multi_task:
                dense_tables, cond2denseid, lists_params, lists_grads, root_params_list, root_grads_list = self._generate_multi_dense_table(
                    params, grads, self._cond_params,
                    self._other_params, sparse_table_names,
                    dense_table_index)
                program_configs[program_id][
                    'cond2denseid'] = cond2denseid
                multi_task_dense_tables_push = dense_tables
                multi_task_dense_tables_pull = dense_tables[:]
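
            # Every dense table created for the multi-task case is remembered in
            # these two lists so the push_dense / pull_dense entries of the
            # program config can be extended with all of them further down.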

            if strategy.get('dense_table') is not None:
                if flag_multi_task:
                    server_dense_table_index = dense_table_index
                    if len(root_params_list) > 0:
                        server.add_dense_table(
                            server_dense_table_index, root_params_list,
                            root_grads_list, strategy['dense_table'],
                            sparse_table_names)
                        server_dense_table_index += 1

                    for i in range(len(lists_params)):
                        server.add_dense_table(
                            server_dense_table_index, lists_params[i],
                            lists_grads[i], strategy['dense_table'],
                            sparse_table_names)
                        server_dense_table_index += 1
                else:
                    server.add_dense_table(
                        dense_table_index, params, grads,
                        strategy['dense_table'], sparse_table_names)
            else:
                server.add_dense_table(dense_table_index, params, grads,
                                       None, sparse_table_names)

            if flag_multi_task:
                if len(root_params_list) > 0:
                    worker.add_dense_table(
                        dense_table_index, self._learning_rate,
                        root_params_list, root_grads_list,
                        dense_start_table_id, sparse_table_names)
                    dense_table_index += 1

                for i in range(len(lists_params)):
                    worker.add_dense_table(
                        dense_table_index, self._learning_rate,
                        lists_params[i], lists_grads[i],
                        dense_start_table_id, sparse_table_names)
                    dense_table_index += 1

                dense_table_index -= 1
            else:
                worker.add_dense_table(
                    dense_table_index, self._learning_rate, params,
                    grads, dense_start_table_id, sparse_table_names)
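
            # Whichever branch ran, dense_table_index now names the dense table
            # that was added last (the multi-task branch decrements after its
            # loop to restore that), and the data-norm handling below bumps it
            # again when it needs a fresh table id.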

            if FLEET_GLOBAL_DICT["enable"]:
                cur_prog = losses[loss_index].block.program
@@ -430,15 +610,28 @@ class DistributedAdam(DistributedOptimizerImplBase):
                    program_id] and "push_dense" in program_configs[
                        program_id] and len(program_configs[program_id][
                            "pull_dense"]) > 0:
                if flag_multi_task:
                    program_configs[program_id]["pull_dense"].extend(
                        multi_task_dense_tables_pull)
                    program_configs[program_id]["push_dense"].extend(
                        multi_task_dense_tables_push)
                else:
                    program_configs[program_id]["pull_dense"].extend(
                        [dense_table_index])
                    program_configs[program_id]["push_dense"].extend(
                        [dense_table_index])
            else:
                if flag_multi_task:
                    program_configs[program_id][
                        "pull_dense"] = multi_task_dense_tables_pull
                    program_configs[program_id][
                        "push_dense"] = multi_task_dense_tables_push
                else:
                    program_configs[program_id][
                        "pull_dense"] = [dense_table_index]
                    program_configs[program_id][
                        "push_dense"] = [dense_table_index]

            if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
                dense_table_index += 1
                if strategy.get('datanorm_table') is not None: