|
|
|
@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
|
|
|
|
|
self.enable_pv_merge = False
|
|
|
|
|
self.merge_by_lineid = False
|
|
|
|
|
self.fleet_send_sleep_seconds = None
|
|
|
|
|
self.trainer_num = -1
|
|
|
|
|
|
|
|
|
|
@deprecated(
|
|
|
|
|
since="2.0.0",
|
|
|
|
@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
|
|
|
|
|
"""
|
|
|
|
|
self.parse_logkey = parse_logkey
|
|
|
|
|
|
|
|
|
|
def _set_trainer_num(self, trainer_num):
|
|
|
|
|
"""
|
|
|
|
|
Set trainer num
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
trainer_num(int): trainer num
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
|
|
|
|
|
dataset._set_trainer_num(1)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
self.trainer_num = trainer_num
|
|
|
|
|
|
|
|
|
|
@deprecated(
|
|
|
|
|
since="2.0.0",
|
|
|
|
|
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
|
|
|
|
@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
|
|
|
|
|
thread_num(int): shuffle thread num. Default is 12.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
trainer_num = 1
|
|
|
|
|
if fleet is not None:
|
|
|
|
|
fleet._role_maker.barrier_worker()
|
|
|
|
|
trainer_num = fleet.worker_num()
|
|
|
|
|
if self.trainer_num == -1:
|
|
|
|
|
self.trainer_num = fleet.worker_num()
|
|
|
|
|
if self.fleet_send_batch_size is None:
|
|
|
|
|
self.fleet_send_batch_size = 1024
|
|
|
|
|
if self.fleet_send_sleep_seconds is None:
|
|
|
|
|
self.fleet_send_sleep_seconds = 0
|
|
|
|
|
self.dataset.register_client2client_msg_handler()
|
|
|
|
|
self.dataset.set_trainer_num(trainer_num)
|
|
|
|
|
self.dataset.set_trainer_num(self.trainer_num)
|
|
|
|
|
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
|
|
|
|
|
self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
|
|
|
|
|
if fleet is not None:
|
|
|
|
|