|
|
|
@ -14,8 +14,10 @@
|
|
|
|
|
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
from control_flow import BlockGuard
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
|
|
|
|
|
__all__ = ['data']
|
|
|
|
|
__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data(name,
|
|
|
|
@ -105,12 +107,14 @@ class ListenAndServ(object):
|
|
|
|
|
which can receive variables from clients and run a block.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, endpoint, fan_in=1):
|
|
|
|
|
self.helper = LayerHelper("recv", name=name)
|
|
|
|
|
def __init__(self, endpoint, fan_in=1, optimizer_mode=True):
|
|
|
|
|
self.helper = LayerHelper("recv")
|
|
|
|
|
self.inputs = []
|
|
|
|
|
self.outputs = []
|
|
|
|
|
self.endpoint = endpoint
|
|
|
|
|
self.fan_in = fan_in
|
|
|
|
|
# FIXME(typhoonzero): Add this switch is stupid
|
|
|
|
|
self.optimizer_mode = optimizer_mode
|
|
|
|
|
|
|
|
|
|
def do(self):
|
|
|
|
|
return BlockGuardServ(self)
|
|
|
|
@ -124,9 +128,16 @@ class ListenAndServ(object):
|
|
|
|
|
grads = list()
|
|
|
|
|
for op in current_block.ops:
|
|
|
|
|
# FIXME(typhoonzero): op.inputs is None if it's cloned.
|
|
|
|
|
if "Grad" in op.inputs and "Param" in op.inputs:
|
|
|
|
|
params.append(op.inputs["Param"].name)
|
|
|
|
|
grads.append(op.inputs["Grad"].name)
|
|
|
|
|
if self.optimizer_mode:
|
|
|
|
|
if "Grad" in op.inputs and "Param" in op.inputs:
|
|
|
|
|
params.append(op.inputs["Param"].name)
|
|
|
|
|
grads.append(op.inputs["Grad"].name)
|
|
|
|
|
else:
|
|
|
|
|
# simple recv mode, recv operators inputs.
|
|
|
|
|
for iname in op.input_names:
|
|
|
|
|
for in_var_name in op.input(iname):
|
|
|
|
|
params.append(parent_block.var(name))
|
|
|
|
|
grads.append(parent_block.var(name))
|
|
|
|
|
|
|
|
|
|
return params, grads
|
|
|
|
|
|
|
|
|
|