|
|
|
@ -75,14 +75,14 @@ class MultiTrainer(TrainerDesc):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _set_program(self, program):
|
|
|
|
|
super(MultiTrainer, self).set_program(program)
|
|
|
|
|
super(MultiTrainer, self)._set_program(program)
|
|
|
|
|
self.program_ = program
|
|
|
|
|
|
|
|
|
|
def _gen_trainer_desc(self):
|
|
|
|
|
super(MultiTrainer, self).gen_trainer_desc()
|
|
|
|
|
super(MultiTrainer, self)._gen_trainer_desc()
|
|
|
|
|
self.proto_desc.class_name = "MultiTrainer"
|
|
|
|
|
self.device_worker_.set_infer(self.infer_)
|
|
|
|
|
self.device_worker_.gen_worker_desc(self.proto_desc)
|
|
|
|
|
self.device_worker_._set_infer(self.infer_)
|
|
|
|
|
self.device_worker_._gen_worker_desc(self.proto_desc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistMultiTrainer(TrainerDesc):
|
|
|
|
@ -91,14 +91,14 @@ class DistMultiTrainer(TrainerDesc):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _set_program(self, program):
|
|
|
|
|
super(DistMultiTrainer, self).set_program(program)
|
|
|
|
|
super(DistMultiTrainer, self)._set_program(program)
|
|
|
|
|
self.program_ = program
|
|
|
|
|
|
|
|
|
|
def _gen_trainer_desc(self):
|
|
|
|
|
super(DistMultiTrainer, self).gen_trainer_desc()
|
|
|
|
|
super(DistMultiTrainer, self)._gen_trainer_desc()
|
|
|
|
|
self.proto_desc.class_name = "DistMultiTrainer"
|
|
|
|
|
if self.program_ == None:
|
|
|
|
|
print("None program")
|
|
|
|
|
self.device_worker_.set_infer(self.infer_)
|
|
|
|
|
self.device_worker_.set_program(self.program_)
|
|
|
|
|
self.device_worker_.gen_worker_desc(self.proto_desc)
|
|
|
|
|
self.device_worker_._set_infer(self.infer_)
|
|
|
|
|
self.device_worker_._set_program(self.program_)
|
|
|
|
|
self.device_worker_._gen_worker_desc(self.proto_desc)
|
|
|
|
|