Merge pull request #7667 from abhinavarora/dist_transpiler

Remove optimize_op argument from get_pserver_program method of Distributed Transpiler
add_depthwiseConv_op_gpu
Darcy 7 years ago committed by GitHub
commit 7a68787667
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -407,7 +407,7 @@ class DistributeTranspiler:
outputs=opt_op.outputs,
attrs=opt_op.attrs)
def get_pserver_program(self, endpoint, optimize_ops):
def get_pserver_program(self, endpoint):
"""
get pserver side program by endpoint
@ -422,9 +422,9 @@ class DistributeTranspiler:
self._clone_var(pserver_program.global_block(), v)
# step6
optimize_sub_program = Program()
for idx, opt_op in enumerate(optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
idx)
for idx, opt_op in enumerate(self.optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
if opt_op.inputs.has_key("Grad"):

@ -53,7 +53,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
else:

@ -197,7 +197,7 @@ def main():
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":

@ -87,7 +87,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":

@ -66,7 +66,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":

@ -60,7 +60,7 @@ if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":

@ -98,7 +98,7 @@ def main():
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()

Loading…
Cancel
Save