|
|
|
@ -78,7 +78,7 @@ class MultiTrainer(TrainerDesc):
|
|
|
|
|
def _gen_trainer_desc(self):
|
|
|
|
|
super(MultiTrainer, self)._gen_trainer_desc()
|
|
|
|
|
self.proto_desc.class_name = "MultiTrainer"
|
|
|
|
|
self._device_worker._set_infer(self.infer_)
|
|
|
|
|
self._device_worker._set_infer(self._infer)
|
|
|
|
|
self._device_worker._gen_worker_desc(self.proto_desc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -96,6 +96,6 @@ class DistMultiTrainer(TrainerDesc):
|
|
|
|
|
self.proto_desc.class_name = "DistMultiTrainer"
|
|
|
|
|
if self._program == None:
|
|
|
|
|
raise RuntimeError("None Program")
|
|
|
|
|
self._device_worker._set_infer(self.infer_)
|
|
|
|
|
self._device_worker._set_program(self.program_)
|
|
|
|
|
self._device_worker._set_infer(self._infer)
|
|
|
|
|
self._device_worker._set_program(self._program)
|
|
|
|
|
self._device_worker._gen_worker_desc(self.proto_desc)
|
|
|
|
|