@@ -143,7 +143,8 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=splitter.round_robin):
+                  split_method=splitter.round_robin,
+                  sync_mode=True):
         """
         Transpile the program to distributed data-parallelism programs.
         The main_program will be transformed to use a remote parameter server
@@ -191,6 +192,7 @@ class DistributeTranspiler:
         self.origin_program = program
         self.trainer_num = trainers
         self.optimize_ops = optimize_ops
+        self.sync_mode = sync_mode
         # TODO(typhoonzero): currently trainer_id is fetched from cluster system
         # like Kubernetes, we should port this to use etcd later when developing
         # fluid distributed training with fault-tolerance.
@@ -473,7 +475,9 @@ class DistributeTranspiler:
                 "OptimizeBlock": optimize_block,
                 "endpoint": endpoint,
                 "Fanin": self.trainer_num,
-                "PrefetchBlock": prefetch_block
+                "PrefetchBlock": prefetch_block,
+                "sync_mode": self.sync_mode,
+                "grad_to_id": []
             })
 
         pserver_program.sync_with_cpp()
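Usage note: a minimal sketch of how the new sync_mode switch could be driven from user code. It assumes the transpiler is exposed as fluid.DistributeTranspiler and that the existing get_pserver_program/get_trainer_program helpers are used afterwards; the endpoints, trainer count, and trainer_id below are placeholder values, not part of this diff.

    import paddle.fluid as fluid

    # Build the usual single-process program first (network + optimizer),
    # then let the transpiler rewrite it for distributed execution.
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id=0,  # normally supplied by the cluster system (see TODO above)
        program=fluid.default_main_program(),
        pservers="127.0.0.1:6174,127.0.0.1:6175",
        trainers=2,
        sync_mode=False)  # new flag from this diff; False selects the non-synchronous path

    # On a parameter-server node: program whose listen_and_serv op now carries
    # the "sync_mode" attribute set above.
    pserver_prog = t.get_pserver_program("127.0.0.1:6174")
    # On a trainer node: program that sends gradients and receives parameters.
    trainer_prog = t.get_trainer_program()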