Add `_set_trainer_num` API to `InMemoryDataset` (#29133)

musl/disable_test_yolov3_temporarily
Thunderbrook 5 years ago committed by GitHub
parent e03440812a
commit 4adddcc89a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
self.enable_pv_merge = False
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
self.trainer_num = -1
@deprecated(
since="2.0.0",
@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_logkey = parse_logkey
def _set_trainer_num(self, trainer_num):
    """
    Record the number of trainers on this dataset.

    The value is only stored on the instance here; nothing is validated
    or forwarded by this call. ``trainer_num`` starts out as -1, and the
    fleet-based shuffle path replaces a value of -1 with
    ``fleet.worker_num()`` before passing it to the underlying dataset,
    so setting it explicitly beforehand takes precedence over that lookup.

    Args:
        trainer_num(int): trainer num

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
          dataset._set_trainer_num(1)

    """
    self.trainer_num = trainer_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
thread_num(int): shuffle thread num. Default is 12.
"""
trainer_num = 1
if fleet is not None:
fleet._role_maker.barrier_worker()
trainer_num = fleet.worker_num()
if self.trainer_num == -1:
self.trainer_num = fleet.worker_num()
if self.fleet_send_batch_size is None:
self.fleet_send_batch_size = 1024
if self.fleet_send_sleep_seconds is None:
self.fleet_send_sleep_seconds = 0
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num)
self.dataset.set_trainer_num(self.trainer_num)
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
if fleet is not None:

@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase):
dataset = fluid.InMemoryDataset()
dataset.set_parse_ins_id(True)
dataset.set_parse_content(True)
dataset._set_trainer_num(1)
self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content)
self.assertEqual(dataset.trainer_num, 1)
def test_run_with_dump(self):
"""

Loading…
Cancel
Save