|
|
|
@ -106,13 +106,18 @@ class ParallelExecutor(object):
|
|
|
|
|
else framework.default_main_program()
|
|
|
|
|
|
|
|
|
|
self._compiled_program = compiler.CompiledProgram(main_program)
|
|
|
|
|
if share_vars_from:
|
|
|
|
|
assert isinstance(
|
|
|
|
|
share_vars_from, ParallelExecutor
|
|
|
|
|
), "The share_vars_from should be ParallelExecutor."
|
|
|
|
|
self._compiled_program.with_data_parallel(
|
|
|
|
|
loss_name=loss_name,
|
|
|
|
|
build_strategy=build_strategy,
|
|
|
|
|
exec_strategy=exec_strategy,
|
|
|
|
|
share_vars_from=share_vars_from)
|
|
|
|
|
share_vars_from=share_vars_from._compiled_program
|
|
|
|
|
if share_vars_from else None)
|
|
|
|
|
self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
|
|
|
|
|
self._executor = executor.Executor(self._place)
|
|
|
|
|
self._exe = executor.Executor(self._place)
|
|
|
|
|
self._compiled_program._compile(place=self._place, scope=self._scope)
|
|
|
|
|
|
|
|
|
|
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
|
|
|
|
@ -180,11 +185,11 @@ class ParallelExecutor(object):
|
|
|
|
|
loss = pe.run(feed=feeder.feed(cur_batch),
|
|
|
|
|
fetch_list=[avg_cost.name]))
|
|
|
|
|
"""
|
|
|
|
|
return self._executor.run(program=self._compiled_program,
|
|
|
|
|
scope=self._scope,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
return_numpy=return_numpy)
|
|
|
|
|
return self._exe.run(program=self._compiled_program,
|
|
|
|
|
scope=self._scope,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
return_numpy=return_numpy)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def device_count(self):
|
|
|
|
|