test=develop
mixed_precision_init
dongdaxiang 6 years ago
parent a659b37ace
commit 8257136012

@ -78,7 +78,7 @@ class MultiTrainer(TrainerDesc):
def _gen_trainer_desc(self):
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self._device_worker._set_infer(self.infer_)
self._device_worker._set_infer(self._infer)
self._device_worker._gen_worker_desc(self.proto_desc)
@ -96,6 +96,6 @@ class DistMultiTrainer(TrainerDesc):
self.proto_desc.class_name = "DistMultiTrainer"
if self._program == None:
raise RuntimeError("None Program")
self._device_worker._set_infer(self.infer_)
self._device_worker._set_program(self.program_)
self._device_worker._set_infer(self._infer)
self._device_worker._set_program(self._program)
self._device_worker._gen_worker_desc(self.proto_desc)

Loading…
Cancel
Save