|
|
|
@ -102,6 +102,7 @@ class Hogwild(DeviceWorker):
|
|
|
|
|
|
|
|
|
|
program_configs = opt_info["program_configs"]
|
|
|
|
|
downpour = trainer_desc.downpour_param
|
|
|
|
|
hogwild = trainer_desc.hogwild_param
|
|
|
|
|
|
|
|
|
|
for pid in program_configs:
|
|
|
|
|
if pid == program_id:
|
|
|
|
@ -154,6 +155,7 @@ class Hogwild(DeviceWorker):
|
|
|
|
|
sparse_table.label_var_name = ""
|
|
|
|
|
if opt_info["stat_var_names"]:
|
|
|
|
|
for i in opt_info["stat_var_names"]:
|
|
|
|
|
hogwild.stat_var_names.extend([i])
|
|
|
|
|
downpour.stat_var_names.extend([i])
|
|
|
|
|
|
|
|
|
|
for i in worker.get_desc().dense_table:
|
|
|
|
@ -163,10 +165,10 @@ class Hogwild(DeviceWorker):
|
|
|
|
|
dense_table.dense_value_name.extend(i.dense_variable_name)
|
|
|
|
|
dense_table.dense_grad_name.extend(
|
|
|
|
|
i.dense_gradient_variable_name)
|
|
|
|
|
downpour.skip_ops.extend(worker.get_desc().skip_op)
|
|
|
|
|
hogwild.skip_ops.extend(worker.get_desc().skip_op)
|
|
|
|
|
if self._infer:
|
|
|
|
|
downpour.push_dense = False
|
|
|
|
|
downpour.push_sparse = False
|
|
|
|
|
hogwild.skip_ops.extend(
|
|
|
|
|
["push_sparse", "push_sparse_v2", "push_dense"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DownpourSGD(DeviceWorker):
|
|
|
|
|