parent
ed55f1b9d4
commit
c70ea1cc30
@ -0,0 +1,38 @@
|
||||
def hash_name(varblocks, pserver_endpoints):
|
||||
"""
|
||||
:param varblocks: a list of VarBlock string indicating
|
||||
sub blocks of variables
|
||||
:return: a map of pserver endpoint -> varblock_str
|
||||
"""
|
||||
|
||||
def _hash_block(block_str, total):
|
||||
return hash(block_str) % total
|
||||
|
||||
ep2block = dict()
|
||||
for varblock_str in varblocks:
|
||||
if param.trainable is True and grad is not None:
|
||||
server_id = _hash_block(varblock_str, len(pserver_endpoints))
|
||||
server_for_param = pserver_endpoints[server_id]
|
||||
if not ep2block.has_key(server_for_param):
|
||||
ep2block[server_for_param] = []
|
||||
ep2block[server_for_param].append(varblock_str)
|
||||
|
||||
return ep2block
|
||||
|
||||
|
||||
def round_robin(varblocks, pserver_endpoints):
|
||||
assert (len(varblocks) > len(pserver_endpoints))
|
||||
|
||||
ep2block = dict()
|
||||
pserver_idx = 0
|
||||
for varblock_str in varblocks:
|
||||
if param.trainable is True:
|
||||
server_for_param = pserver_endpoints[pserver_idx]
|
||||
if not ep2block.has_key(server_for_param):
|
||||
ep2block[server_for_param] = []
|
||||
ep2block[server_for_param].append(varblock_str)
|
||||
|
||||
pserver_idx += 1
|
||||
if pserver_idx >= len(pserver_endpoints):
|
||||
pserver_idx = 0
|
||||
return ep2block
|
Loading…
Reference in new issue