|
|
|
@ -113,9 +113,9 @@ class ListenAndServ(object):
|
|
|
|
|
which can receive variables from clients and run a block.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, endpoint, fan_in=1, optimizer_mode=True):
|
|
|
|
|
def __init__(self, endpoint, inputs, fan_in=1, optimizer_mode=True):
|
|
|
|
|
self.helper = LayerHelper("listen_and_serv")
|
|
|
|
|
self.inputs = []
|
|
|
|
|
self.inputs = inputs
|
|
|
|
|
self.outputs = []
|
|
|
|
|
self.endpoint = endpoint
|
|
|
|
|
self.fan_in = fan_in
|
|
|
|
@ -160,18 +160,13 @@ class ListenAndServ(object):
|
|
|
|
|
current_block = main_program.current_block()
|
|
|
|
|
parent_block = self.parent_block()
|
|
|
|
|
|
|
|
|
|
params, grads = self.get_params_and_grads()
|
|
|
|
|
param_names = [p.name for p in params]
|
|
|
|
|
grad_names = [g.name for g in grads]
|
|
|
|
|
parent_block.append_op(
|
|
|
|
|
type='listen_and_serv',
|
|
|
|
|
inputs={},
|
|
|
|
|
inputs={"X": self.inputs},
|
|
|
|
|
outputs={},
|
|
|
|
|
attrs={
|
|
|
|
|
'endpoint': self.endpoint,
|
|
|
|
|
'Fanin': self.fan_in,
|
|
|
|
|
'ParamList': param_names,
|
|
|
|
|
'GradList': grad_names,
|
|
|
|
|
'OptimizeBlock': current_block
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
@ -196,10 +191,14 @@ def Send(endpoints, send_vars, get_vars):
|
|
|
|
|
endpoints = list(set(epmap))
|
|
|
|
|
|
|
|
|
|
helper = LayerHelper("Send", **locals())
|
|
|
|
|
rpc_client_var = default_main_program().global_block().create_var(
|
|
|
|
|
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type="send",
|
|
|
|
|
inputs={"X": send_vars},
|
|
|
|
|
outputs={"Out": get_vars},
|
|
|
|
|
outputs={"Out": get_vars,
|
|
|
|
|
"RPCClient": rpc_client_var},
|
|
|
|
|
attrs={"endpoints": endpoints,
|
|
|
|
|
"epmap": epmap})
|
|
|
|
|
|
|
|
|
|