add data_generator package into setup.py

revert-16555-model_data_cryption_link_all_lib
dongdaxiang 7 years ago
parent 17790188d0
commit 8e14d8f900

@ -72,7 +72,11 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
'io',
'initializer',
'layers',

@ -654,7 +654,7 @@ class Executor(object):
trainer._set_thread(thread)
trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return trainer
return scope, trainer
def infer_from_dataset(self,
program=None,
@ -702,7 +702,7 @@ class Executor(object):
dataset=dataset)
"""
trainer = self._prepare_trainer(
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
@ -775,7 +775,7 @@ class Executor(object):
"""
trainer = self._prepare_trainer(
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,

@ -75,14 +75,14 @@ class MultiTrainer(TrainerDesc):
pass
def _set_program(self, program):
super(MultiTrainer, self).set_program(program)
super(MultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.set_infer(self.infer_)
self.device_worker_.gen_worker_desc(self.proto_desc)
self.device_worker_._set_infer(self.infer_)
self.device_worker_._gen_worker_desc(self.proto_desc)
class DistMultiTrainer(TrainerDesc):
@ -91,14 +91,14 @@ class DistMultiTrainer(TrainerDesc):
pass
def _set_program(self, program):
super(DistMultiTrainer, self).set_program(program)
super(DistMultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
print("None program")
self.device_worker_.set_infer(self.infer_)
self.device_worker_.set_program(self.program_)
self.device_worker_.gen_worker_desc(self.proto_desc)
self.device_worker_._set_infer(self.infer_)
self.device_worker_._set_program(self.program_)
self.device_worker_._gen_worker_desc(self.proto_desc)

@ -29,13 +29,13 @@ class TrainerFactory(object):
# default is MultiTrainer + Hogwild
trainer = MultiTrainer()
device_worker = Hogwild()
trainer.set_device_worker(device_worker)
trainer._set_device_worker(device_worker)
else:
trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
device_worker.set_fleet_desc(opt_info["fleet_desc"])
trainer.set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"])
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(opt_info["fleet_desc"])
return trainer

@ -122,6 +122,7 @@ packages=['paddle',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',

Loading…
Cancel
Save