Fix a minor bug for distributed_spliter.round_robin

Also fixed typo and comments.
fea/docker_cudnn7
xuwei06 7 years ago
parent b2a1c9e8b7
commit 560d960b27

@ -17,7 +17,7 @@ import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
from distributed_spliter import *
import distributed_splitter as splitter
import math
from . import core
import debuger
@ -138,7 +138,7 @@ class DistributeTranspiler:
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=round_robin):
split_method=splitter.round_robin):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server

@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints):
"""
hash variable names to several endpoints.
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
"""
def _hash_block(block_str, total):
@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints):
def round_robin(varlist, pserver_endpoints):
"""
distribute variables to several endpoints.
Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
"""
assert (len(varlist) > len(pserver_endpoints))
assert (len(varlist) >= len(pserver_endpoints))
eplist = []
pserver_idx = 0
Loading…
Cancel
Save