【paddle.distributed.fleet】add data_generator in distributed.fleet.dataset (#27345)

Branch: my_2.0rc
Author: yaoxuefeng (committed via GitHub)
Parent: aac57159c9
Commit: 780140599f
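For context: this commit moves the data generator API from `paddle.fluid.incubate.data_generator` into `paddle.distributed.fleet`, exporting `MultiSlotDataGenerator` and `MultiSlotStringDataGenerator` there, and marks the legacy `paddle.fluid` dataset methods as deprecated. A minimal usage sketch of the relocated API (the class name and input line format below are hypothetical; the base class, `generate_sample`, and `run_from_stdin` come from this diff):

import paddle.distributed.fleet as fleet

class WordLabelGenerator(fleet.MultiSlotDataGenerator):  # hypothetical name
    def generate_sample(self, line):
        # Return a no-argument generator; every sample it yields is a
        # sequence of (slot_name, list_of_values) pairs.
        def data_iter():
            elements = line.strip().split()  # assumed "w1 w2 ... label" format
            yield ("words", [int(e) for e in elements[:-1]]), \
                  ("label", [int(elements[-1])])

        return data_iter

if __name__ == "__main__":
    # Run as a dataset pipe command: reads raw lines from stdin,
    # writes serialized slots to stdout.
    WordLabelGenerator().run_from_stdin()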

@@ -31,8 +31,13 @@ __all__ = ["spawn"]

 # dygraph parallel apis
 __all__ += [
-    "init_parallel_env", "get_rank", "get_world_size", "prepare_context",
-    "ParallelEnv", "InMemoryDataset", "QueueDataset"
+    "init_parallel_env",
+    "get_rank",
+    "get_world_size",
+    "prepare_context",
+    "ParallelEnv",
+    "InMemoryDataset",
+    "QueueDataset",
 ]

 # collective apis

@@ -18,6 +18,7 @@ from .base.distributed_strategy import DistributedStrategy
 from .base.fleet_base import Fleet
 from .base.util_factory import UtilBase
 from .dataset import *
+from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
 #from . import metrics

 __all__ = [
@@ -26,6 +27,8 @@ __all__ = [
     "UserDefinedRoleMaker",
     "PaddleCloudRoleMaker",
     "Fleet",
+    "MultiSlotDataGenerator",
+    "MultiSlotStringDataGenerator",
     "Role",
 ]

@@ -0,0 +1,14 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_generator import *

@@ -0,0 +1,39 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed.fleet as fleet


class SyntheticData(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(10000):
                yield ("words", [1, 2, 3, 4]), ("label", [0])

        return data_iter


class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(10000):
                yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]

        return data_iter

sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd2.run_from_memory()
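Note how the two base classes mirror each other: `MultiSlotDataGenerator` carries numeric slot values while `MultiSlotStringDataGenerator` carries string values, and a sample may be yielded as either a tuple or a list of (name, values) pairs, as the two classes above show. `run_from_memory()` drives a generator without any input source, which is why these synthetic examples (and the unit tests later in this diff) use it for quick self-checks; `run_from_stdin()` is the entry point used when the script is wired to a dataset as a pipe command.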

@@ -119,7 +119,7 @@ class DatasetBase(object):
     def set_filelist(self, filelist):
         """
-        Set file list in current worker.
+        Set file list in current worker. The filelist is indicated by a list of file names (string).

         Examples:
             .. code-block:: python
@@ -129,7 +129,7 @@ class DatasetBase(object):
                 dataset.set_filelist(['a.txt', 'b.txt'])

         Args:
-            filelist(list): file list
+            filelist(list[str]): list of file names of inputs.
         """
         self.dataset.set_filelist(filelist)
         self.filelist = filelist
@@ -240,6 +240,8 @@ class DatasetBase(object):
 class InMemoryDataset(DatasetBase):
     """
+    :api_attr: Static Graph
+
     InMemoryDataset, it will load data into memory
     and shuffle data before training.
@@ -265,6 +267,8 @@ class InMemoryDataset(DatasetBase):
     def _init_distributed_settings(self, **kwargs):
         """
+        :api_attr: Static Graph
+
         should be called only once in user's python scripts to initialize distributed-related setings of dataset instance

         Args:
             kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@@ -323,6 +327,8 @@ class InMemoryDataset(DatasetBase):
     def update_settings(self, **kwargs):
         """
+        :api_attr: Static Graph
+
         should be called in user's python scripts to update setings of dataset instance

         Args:
             kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
@@ -400,6 +406,8 @@ class InMemoryDataset(DatasetBase):
     def init(self, **kwargs):
         """
+        :api_attr: Static Graph
+
         should be called only once in user's python scripts to initialize setings of dataset instance

         Args:
             kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@@ -450,11 +458,16 @@ class InMemoryDataset(DatasetBase):
                 ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
             dataset.load_into_memory()
-            exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
-            ) else fluid.CUDAPlace(0))
-            exe.run(fluid.default_startup_program())
-            exe.train_from_dataset(fluid.default_main_program(),
-                                   dataset)
+            paddle.enable_static()
+
+            place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
+            startup_program = paddle.static.Program()
+            main_program = paddle.static.Program()
+            exe.run(startup_program)
+            exe.train_from_dataset(main_program, dataset)
             os.remove("./test_queue_dataset_run_a.txt")
             os.remove("./test_queue_dataset_run_b.txt")
         """
@@ -639,6 +652,8 @@ class InMemoryDataset(DatasetBase):
     def load_into_memory(self):
         """
+        :api_attr: Static Graph
+
         Load data into memory

         Examples:
@@ -655,6 +670,8 @@ class InMemoryDataset(DatasetBase):
     def preload_into_memory(self, thread_num=None):
         """
+        :api_attr: Static Graph
+
         Load data into memory in async mode

         Args:
@@ -679,6 +696,8 @@ class InMemoryDataset(DatasetBase):
     def wait_preload_done(self):
         """
+        :api_attr: Static Graph
+
         Wait preload_into_memory done

         Examples:
@@ -696,6 +715,8 @@ class InMemoryDataset(DatasetBase):
     def local_shuffle(self):
         """
+        :api_attr: Static Graph
+
         Local shuffle

         Examples:
@@ -712,6 +733,8 @@ class InMemoryDataset(DatasetBase):
     def global_shuffle(self, fleet=None, thread_num=12):
         """
+        :api_attr: Static Graph
+
         Global shuffle.
         Global shuffle can be used only in distributed mode. i.e. multiple
         processes on single machine or multiple machines training together.
@@ -771,9 +794,11 @@ class InMemoryDataset(DatasetBase):
             dataset.set_filelist(filelist)
             dataset.load_into_memory()
             dataset.global_shuffle(fleet)
-            exe = fluid.Executor(fluid.CPUPlace())
-            exe.run(fluid.default_startup_program())
-            exe.train_from_dataset(fluid.default_main_program(), dataset)
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            startup_program = paddle.static.Program()
+            main_program = paddle.static.Program()
+            exe.run(startup_program)
+            exe.train_from_dataset(main_program, dataset)
             dataset.release_memory()
         """
@@ -781,6 +806,8 @@ class InMemoryDataset(DatasetBase):
     def get_memory_data_size(self, fleet=None):
         """
+        :api_attr: Static Graph
+
         Get memory data size, user can call this function to know the num
         of ins in all workers after load into memory.
@@ -817,6 +844,8 @@ class InMemoryDataset(DatasetBase):
     def get_shuffle_data_size(self, fleet=None):
         """
+        :api_attr: Static Graph
+
         Get shuffle data size, user can call this function to know the num
         of ins in all workers after local/global shuffle.
@@ -901,6 +930,8 @@ class InMemoryDataset(DatasetBase):
 class QueueDataset(DatasetBase):
     """
+    :api_attr: Static Graph
+
     QueueDataset, it will process data streamly.

     Examples:
@@ -920,6 +951,8 @@ class QueueDataset(DatasetBase):
     def init(self, **kwargs):
         """
+        :api_attr: Static Graph
+
         should be called only once in user's python scripts to initialize setings of dataset instance
         """
         super(QueueDataset, self).init(**kwargs)

@@ -16,6 +16,7 @@
 from paddle.fluid.proto import data_feed_pb2
 from google.protobuf import text_format
 from . import core
+from ..utils import deprecated

 __all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
@@ -335,6 +336,7 @@ class InMemoryDataset(DatasetBase):
         dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
     """

+    @deprecated(since="2.0.0", update_to="paddle.distributed.InMemoryDataset")
     def __init__(self):
         """ Init. """
         super(InMemoryDataset, self).__init__()
@@ -350,12 +352,18 @@ class InMemoryDataset(DatasetBase):
         self.merge_by_lineid = False
         self.fleet_send_sleep_seconds = None

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_feed_type")
     def set_feed_type(self, data_feed_type):
         """
         Set data_feed_desc
         """
         self.proto_desc.name = data_feed_type

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._prepare_to_run")
     def _prepare_to_run(self):
         """
         Set data_feed_desc before load or shuffle,
@@ -376,16 +384,27 @@ class InMemoryDataset(DatasetBase):
         self.dataset.create_channel()
         self.dataset.create_readers()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_before_train"
+    )
     def _dynamic_adjust_before_train(self, thread_num):
         if not self.is_user_set_queue_num:
             self.dataset.dynamic_adjust_channel_num(thread_num, False)
         self.dataset.dynamic_adjust_readers_num(thread_num)

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_after_train"
+    )
     def _dynamic_adjust_after_train(self):
         if not self.is_user_set_queue_num:
             self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
         self.dataset.dynamic_adjust_readers_num(self.thread_num)

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_queue_num")
     def set_queue_num(self, queue_num):
         """
         Set Dataset output queue num, training threads get data from queues
@@ -404,6 +423,9 @@ class InMemoryDataset(DatasetBase):
         self.is_user_set_queue_num = True
         self.queue_num = queue_num

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_parse_ins_id")
     def set_parse_ins_id(self, parse_ins_id):
         """
         Set id Dataset need to parse insid
@@ -421,6 +443,9 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_ins_id = parse_ins_id

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_parse_content")
     def set_parse_content(self, parse_content):
         """
         Set if Dataset need to parse content
@@ -455,6 +480,9 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_logkey = parse_logkey

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
     def set_merge_by_sid(self, merge_by_sid):
         """
         Set if Dataset need to merge sid. If not, one ins means one Pv.
@@ -544,6 +572,10 @@ class InMemoryDataset(DatasetBase):
         """
         self.dataset.postprocess_instance()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_fleet_send_batch_size"
+    )
     def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
         """
         Set fleet send batch size, default is 1024
@@ -561,6 +593,10 @@ class InMemoryDataset(DatasetBase):
         """
         self.fleet_send_batch_size = fleet_send_batch_size

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds"
+    )
     def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
         """
         Set fleet send sleep time, default is 0
@@ -578,6 +614,9 @@ class InMemoryDataset(DatasetBase):
         """
         self.fleet_send_sleep_seconds = fleet_send_sleep_seconds

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_merge_by_lineid")
     def set_merge_by_lineid(self, merge_size=2):
         """
         Set merge by line id, instances of same line id will be merged after
@@ -598,16 +637,27 @@ class InMemoryDataset(DatasetBase):
         self.merge_by_lineid = True
         self.parse_ins_id = True

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._set_generate_unique_feasigns"
+    )
     def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
         self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
         self.gen_uni_feasigns = generate_uni_feasigns
         self.local_shard_num = shard_num

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset._generate_local_tables_unlock"
+    )
     def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
                                      consume_thread_num, shard_num):
         self.dataset.generate_local_tables_unlock(
             table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.load_into_memory")
     def load_into_memory(self):
         """
         Load data into memory
@@ -624,6 +674,9 @@ class InMemoryDataset(DatasetBase):
         self._prepare_to_run()
         self.dataset.load_into_memory()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.preload_into_memory")
     def preload_into_memory(self, thread_num=None):
         """
         Load data into memory in async mode
@@ -648,6 +701,9 @@ class InMemoryDataset(DatasetBase):
         self.dataset.create_preload_readers()
         self.dataset.preload_into_memory()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.wait_preload_done")
     def wait_preload_done(self):
         """
         Wait preload_into_memory done
@@ -665,6 +721,9 @@ class InMemoryDataset(DatasetBase):
         self.dataset.wait_preload_done()
         self.dataset.destroy_preload_readers()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.local_shuffle")
     def local_shuffle(self):
         """
         Local shuffle
@@ -681,6 +740,9 @@ class InMemoryDataset(DatasetBase):
         """
         self.dataset.local_shuffle()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.global_shuffle")
     def global_shuffle(self, fleet=None, thread_num=12):
         """
         Global shuffle.
@@ -726,6 +788,9 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet._role_maker.barrier_worker()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.release_memory")
     def release_memory(self):
         """
         :api_attr: Static Graph
@@ -774,6 +839,9 @@ class InMemoryDataset(DatasetBase):
         """
         return self.dataset.get_pv_data_size()

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.get_memory_data_size")
     def get_memory_data_size(self, fleet=None):
         """
         Get memory data size, user can call this function to know the num
@@ -810,6 +878,9 @@ class InMemoryDataset(DatasetBase):
             return global_data_size[0]
         return local_data_size[0]

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.InMemoryDataset.get_shuffle_data_size")
     def get_shuffle_data_size(self, fleet=None):
         """
         Get shuffle data size, user can call this function to know the num
@@ -869,6 +940,9 @@ class QueueDataset(DatasetBase):
         super(QueueDataset, self).__init__()
         self.proto_desc.name = "MultiSlotDataFeed"

+    @deprecated(
+        since="2.0.0",
+        update_to="paddle.distributed.QueueDataset._prepare_to_run")
     def _prepare_to_run(self):
         """
         Set data_feed_desc/thread num/filelist before run,
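All of the `@deprecated(since=..., update_to=...)` markers above follow one pattern: the legacy `paddle.fluid` method keeps working but points the caller at its `paddle.distributed` replacement. A rough sketch of that pattern (an illustrative assumption, not Paddle's actual implementation):

import functools
import warnings

def deprecated(since, update_to):
    # Decorator factory: wrap a callable so each call emits a
    # DeprecationWarning naming the replacement, then runs as before.
    def wrap(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                "'{}' is deprecated since {}, use '{}' instead.".format(
                    func.__name__, since, update_to),
                category=DeprecationWarning,
                stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return wrap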

@@ -19,7 +19,7 @@ import tarfile
 import os

 import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet
 from paddle.fluid.log_helper import get_logger

 logger = get_logger(
@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
     return res


-class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
+class DatasetCtrReader(fleet.MultiSlotDataGenerator):
     def generate_sample(self, line):
         def iter():
             fs = line.strip().split('\t')

@@ -22,7 +22,7 @@ import random
 import warnings

 import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet

 logging.basicConfig()
 logger = logging.getLogger("paddle")
@@ -84,7 +84,7 @@ class CtrReader(object):
         return reader


-class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
+class DatasetCtrReader(fleet.MultiSlotDataGenerator):
     def generate_sample(self, line):
         def get_rand(low=0.0, high=1.0):
             return random.random()

@@ -0,0 +1,38 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import paddle
import re
import collections
import time
import paddle.distributed.fleet as fleet


class MyDataset(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            elements = line.strip().split()[0:]
            output = [("show", [int(elements[0])]),
                      ("click", [int(elements[1])]),
                      ("slot1", [int(elements[2])])]
            yield output

        return data_iter


if __name__ == "__main__":
    d = MyDataset()
    d.run_from_stdin()
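When attached as a pipe command, a `MultiSlotDataGenerator` serializes each yielded sample to stdout slot by slot as `<number of values> <value ...>`. For the generator above, the input line `2 1 2` therefore becomes `1 2 1 1 1 2` (length 1, value 2 for `show`; length 1, value 1 for `click`; length 1, value 2 for `slot1`), which is exactly what the pipe-command unit test later in this diff asserts after running:

cat test_queue_dataset_run_a.txt | python my_data_generator.py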

@@ -21,13 +21,13 @@ import tarfile
 import random

 import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet

 logging.basicConfig()
 logger = logging.getLogger("paddle")
 logger.setLevel(logging.INFO)


-class DatasetSimnetReader(data_generator.MultiSlotDataGenerator):
+class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
     def generate_sample(self, line):
         pass

@@ -0,0 +1,176 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import paddle.distributed.fleet as fleet
import os
import sys
import platform


class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", [1, 2, 3, 4]), ("label", [0])

        return data_iter


class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", ["1", "2", "3", "4"]), ("label", ["0"])

        return data_iter


class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield "words"

        return data_iter


class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield "words"

        return data_iter


class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield (1, ["1", "2", "3", "4"]), (2, ["0"])

        return data_iter


class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", "1"), ("label", "0")

        return data_iter


class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", []), ("label", [])

        return data_iter


class TestMultiSlotDataGenerator(unittest.TestCase):
    def test_MultiSlotDataGenerator_basic(self):
        my_ms_dg = MyMultiSlotDataGenerator()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGenerator(unittest.TestCase):
    def test_MyMultiSlotStringDataGenerator_basic(self):
        my_ms_dg = MyMultiSlotStringDataGenerator()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGenerator_2(unittest.TestCase):
    def test_MyMultiSlotStringDataGenerator_stdin(self):
        plats = platform.platform()
        if 'Linux' not in plats:
            print("skip pipecommand UT on MacOS/Win")
            return
        with open("test_queue_dataset_run_a.txt", "w") as f:
            data = "2 1 2\n"
            data += "2 6 2\n"
            data += "2 5 2\n"
            data += "2 7 2\n"
            f.write(data)

        tmp = os.popen(
            "cat test_queue_dataset_run_a.txt | python my_data_generator.py"
        ).readlines()
        expected_res = [
            '1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n'
        ]
        self.assertEqual(tmp, expected_res)
        os.remove("./test_queue_dataset_run_a.txt")


class TestMultiSlotDataGenerator_error(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_2(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_2()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_3(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_3()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_4(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_4()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_5()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


if __name__ == '__main__':
    unittest.main()
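Taken together, these tests pin down the validation contract of `generate_sample`: yielding a bare string instead of (name, values) pairs, using non-string slot names, yielding a scalar where a list of values is expected, or yielding an empty value list each raise `ValueError`, while a stray `yield None` is tolerated. `set_batch(1)` presumably sets the output batch size; keeping it at 1 makes each sample map to one output record.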

@@ -105,11 +105,15 @@ class TestDataset(unittest.TestCase):
         dataset.load_into_memory()
         dataset.local_shuffle()

-        exe = fluid.Executor(fluid.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        paddle.enable_static()
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        exe.run(startup_program)
         for i in range(2):
             try:
-                exe.train_from_dataset(fluid.default_main_program(), dataset)
+                exe.train_from_dataset(main_program, dataset)
             except ImportError as e:
                 pass
             except Exception as e:
@@ -181,20 +185,24 @@ class TestDataset(unittest.TestCase):
                                       use_var=slots_vars)
         dataset.set_filelist([filename1, filename2])
         dataset.load_into_memory()
+        paddle.enable_static()
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
         exe = fluid.Executor(fluid.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(startup_program)
         if self.use_data_loader:
             data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                            fluid.cpu_places(),
                                                            self.drop_last)
             for i in range(self.epoch_num):
                 for data in data_loader():
-                    exe.run(fluid.default_main_program(), feed=data)
+                    exe.run(main_program, feed=data)
         else:
             for i in range(self.epoch_num):
                 try:
-                    exe.train_from_dataset(fluid.default_main_program(),
-                                           dataset)
+                    exe.train_from_dataset(main_program, dataset)
                 except Exception as e:
                     self.assertTrue(False)

@@ -150,6 +150,7 @@ packages=['paddle',
           'paddle.distributed.fleet.meta_optimizers',
           'paddle.distributed.fleet.runtime',
           'paddle.distributed.fleet.dataset',
+          'paddle.distributed.fleet.data_generator',
           'paddle.distributed.fleet.metrics',
           'paddle.distributed.fleet.proto',
           'paddle.distributed.fleet.utils',
