|
|
@ -126,12 +126,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
|
|
|
|
[optimize_ops, grads_and_weights]
|
|
|
|
[optimize_ops, grads_and_weights]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
table_name = self._find_multi_distributed_lookup_table(losses)
|
|
|
|
sparse_table_names = self._find_multi_distributed_lookup_table(losses)
|
|
|
|
inputs_dict = self._find_distributed_lookup_table_inputs(
|
|
|
|
inputs_dict = self._find_distributed_lookup_table_inputs(
|
|
|
|
losses[0].block.program, table_name)
|
|
|
|
losses[0].block.program, sparse_table_names)
|
|
|
|
|
|
|
|
|
|
|
|
outputs_dict = self._find_distributed_lookup_table_outputs(
|
|
|
|
outputs_dict = self._find_distributed_lookup_table_outputs(
|
|
|
|
losses[0].block.program, table_name)
|
|
|
|
losses[0].block.program, sparse_table_names)
|
|
|
|
|
|
|
|
|
|
|
|
ps_param = pslib.PSParameter()
|
|
|
|
ps_param = pslib.PSParameter()
|
|
|
|
server = DownpourServer()
|
|
|
|
server = DownpourServer()
|
|
|
@ -147,7 +147,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
|
|
|
|
worker.get_desc().CopyFrom(ps_param.trainer_param)
|
|
|
|
worker.get_desc().CopyFrom(ps_param.trainer_param)
|
|
|
|
|
|
|
|
|
|
|
|
sparse_table_index = 0
|
|
|
|
sparse_table_index = 0
|
|
|
|
for tn in table_name:
|
|
|
|
for tn in sparse_table_names:
|
|
|
|
if strategy.get(tn) is not None:
|
|
|
|
if strategy.get(tn) is not None:
|
|
|
|
server.add_sparse_table(sparse_table_index, strategy[tn])
|
|
|
|
server.add_sparse_table(sparse_table_index, strategy[tn])
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -199,13 +199,14 @@ class DistributedAdam(DistributedOptimizerImplBase):
|
|
|
|
|
|
|
|
|
|
|
|
if strategy.get('dense_table') is not None:
|
|
|
|
if strategy.get('dense_table') is not None:
|
|
|
|
server.add_dense_table(dense_table_index, params, grads,
|
|
|
|
server.add_dense_table(dense_table_index, params, grads,
|
|
|
|
strategy['dense_table'], table_name)
|
|
|
|
strategy['dense_table'],
|
|
|
|
|
|
|
|
sparse_table_names)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
server.add_dense_table(dense_table_index, params, grads, None,
|
|
|
|
server.add_dense_table(dense_table_index, params, grads, None,
|
|
|
|
table_name)
|
|
|
|
sparse_table_names)
|
|
|
|
worker.add_dense_table(dense_table_index, self._learning_rate,
|
|
|
|
worker.add_dense_table(dense_table_index, self._learning_rate,
|
|
|
|
params, grads, dense_start_table_id,
|
|
|
|
params, grads, dense_start_table_id,
|
|
|
|
table_name)
|
|
|
|
sparse_table_names)
|
|
|
|
program_configs[program_id]["pull_dense"] = [dense_table_index]
|
|
|
|
program_configs[program_id]["pull_dense"] = [dense_table_index]
|
|
|
|
program_configs[program_id]["push_dense"] = [dense_table_index]
|
|
|
|
program_configs[program_id]["push_dense"] = [dense_table_index]
|
|
|
|
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
|
|
|
|
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
|
|
|
@ -214,15 +215,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
|
|
|
|
server.add_data_norm_table(
|
|
|
|
server.add_data_norm_table(
|
|
|
|
dense_table_index, self._learning_rate,
|
|
|
|
dense_table_index, self._learning_rate,
|
|
|
|
data_norm_params, data_norm_grads,
|
|
|
|
data_norm_params, data_norm_grads,
|
|
|
|
strategy['datanorm_table'], table_name)
|
|
|
|
strategy['datanorm_table'], sparse_table_names)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
server.add_data_norm_table(
|
|
|
|
server.add_data_norm_table(
|
|
|
|
dense_table_index, self._learning_rate,
|
|
|
|
dense_table_index, self._learning_rate,
|
|
|
|
data_norm_params, data_norm_grads, None, table_name)
|
|
|
|
data_norm_params, data_norm_grads, None,
|
|
|
|
|
|
|
|
sparse_table_names)
|
|
|
|
|
|
|
|
|
|
|
|
worker.add_dense_table(dense_table_index, self._learning_rate,
|
|
|
|
worker.add_dense_table(dense_table_index, self._learning_rate,
|
|
|
|
data_norm_params, data_norm_grads,
|
|
|
|
data_norm_params, data_norm_grads,
|
|
|
|
dense_start_table_id, table_name)
|
|
|
|
dense_start_table_id, sparse_table_names)
|
|
|
|
program_configs[program_id]["pull_dense"].extend(
|
|
|
|
program_configs[program_id]["pull_dense"].extend(
|
|
|
|
[dense_table_index])
|
|
|
|
[dense_table_index])
|
|
|
|
program_configs[program_id]["push_dense"].extend(
|
|
|
|
program_configs[program_id]["push_dense"].extend(
|
|
|
|