【paddle.distributed.fleet】add data_generator in distributed.fleet.dataset (#27345)

my_2.0rc
yaoxuefeng 4 years ago committed by GitHub
parent aac57159c9
commit 780140599f

@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv", "InMemoryDataset", "QueueDataset"
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"InMemoryDataset",
"QueueDataset",
]
# collective apis

@ -18,6 +18,7 @@ from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet
from .base.util_factory import UtilBase
from .dataset import *
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
#from . import metrics
__all__ = [
@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
"MultiSlotDataGenerator",
"MultiSlotStringDataGenerator",
"Role",
]
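
With these two generators exported here, the new API is reachable straight from the fleet namespace. A quick smoke check (assuming a build that contains this commit):

import paddle.distributed.fleet as fleet

assert hasattr(fleet, "MultiSlotDataGenerator")
assert hasattr(fleet, "MultiSlotStringDataGenerator")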

@ -0,0 +1,14 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_generator import *

@ -0,0 +1,39 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.fleet as fleet


class SyntheticData(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(10000):
                yield ("words", [1, 2, 3, 4]), ("label", [0])

        return data_iter


class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(10000):
                yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]

        return data_iter


sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd2.run_from_memory()
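
A note on the yield format, inferred from the unit tests added later in this commit: each sample must be a sequence of (slot_name, value_list) pairs, where names are strings and values are non-empty lists — ints/floats for MultiSlotDataGenerator, strings for MultiSlotStringDataGenerator. A sketch of the valid and invalid shapes:

import paddle.distributed.fleet as fleet

class ProtocolDemo(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            # Valid: string slot names, non-empty lists of int values.
            yield ("words", [1, 2, 3, 4]), ("label", [0])
            # Each shape below raises ValueError at serialization time,
            # per the tests added in this commit:
            #   yield "words"                          # bare string
            #   yield (1, ["1", "2"]), (2, ["0"])      # non-string slot name
            #   yield ("words", "1"), ("label", "0")   # scalar values
            #   yield ("words", []), ("label", [])     # empty value lists
        return data_iter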

@ -119,7 +119,7 @@ class DatasetBase(object):
def set_filelist(self, filelist):
"""
-Set file list in current worker.
+Set file list in current worker. The filelist is a list of file names (strings).
Examples:
.. code-block:: python
@ -129,7 +129,7 @@ class DatasetBase(object):
dataset.set_filelist(['a.txt', 'b.txt'])
Args:
-filelist(list): file list
+filelist(list[str]): list of input file names.
"""
self.dataset.set_filelist(filelist)
self.filelist = filelist
@ -240,6 +240,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
:api_attr: Static Graph
InMemoryDataset loads data into memory and shuffles the data before training.
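
For orientation, a minimal end-to-end sketch of this workflow under the 2.0 static-graph API used in this diff; the slot variable and the init kwargs are illustrative, not prescriptive:

import paddle

paddle.enable_static()
# Illustrative input variable; real slot definitions depend on the model.
slots_vars = [paddle.static.data(name="words", shape=[None, 1], dtype="int64", lod_level=1)]

dataset = paddle.distributed.InMemoryDataset()
dataset.init(use_var=slots_vars)
dataset.set_filelist(["a.txt", "b.txt"])
dataset.load_into_memory()   # pull everything into memory first
dataset.local_shuffle()      # then shuffle before training

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
exe.train_from_dataset(paddle.static.default_main_program(), dataset)
dataset.release_memory()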
@ -265,6 +267,8 @@ class InMemoryDataset(DatasetBase):
def _init_distributed_settings(self, **kwargs):
"""
:api_attr: Static Graph
Should be called only once in the user's Python script to initialize the distributed-related settings of the dataset instance.
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@ -323,6 +327,8 @@ class InMemoryDataset(DatasetBase):
def update_settings(self, **kwargs):
"""
:api_attr: Static Graph
Should be called in the user's Python script to update settings of the dataset instance.
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
@ -400,6 +406,8 @@ class InMemoryDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph
Should be called only once in the user's Python script to initialize settings of the dataset instance.
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@ -450,11 +458,16 @@ class InMemoryDataset(DatasetBase):
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.load_into_memory()
-exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
-) else fluid.CUDAPlace(0))
-exe.run(fluid.default_startup_program())
-exe.train_from_dataset(fluid.default_main_program(),
-dataset)
+paddle.enable_static()
+place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
+exe = paddle.static.Executor(place)
+startup_program = paddle.static.Program()
+main_program = paddle.static.Program()
+exe.run(startup_program)
+exe.train_from_dataset(main_program, dataset)
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
"""
@ -639,6 +652,8 @@ class InMemoryDataset(DatasetBase):
def load_into_memory(self):
"""
:api_attr: Static Graph
Load data into memory
Examples:
@ -655,6 +670,8 @@ class InMemoryDataset(DatasetBase):
def preload_into_memory(self, thread_num=None):
"""
:api_attr: Static Graph
Load data into memory in async mode
Args:
@ -679,6 +696,8 @@ class InMemoryDataset(DatasetBase):
def wait_preload_done(self):
"""
:api_attr: Static Graph
Wait until preload_into_memory is done.
Examples:
@ -696,6 +715,8 @@ class InMemoryDataset(DatasetBase):
def local_shuffle(self):
"""
:api_attr: Static Graph
Local shuffle
Examples:
@ -712,6 +733,8 @@ class InMemoryDataset(DatasetBase):
def global_shuffle(self, fleet=None, thread_num=12):
"""
:api_attr: Static Graph
Global shuffle.
Global shuffle can be used only in distributed mode, i.e. multiple
processes on a single machine or multiple machines training together.
@ -771,9 +794,11 @@ class InMemoryDataset(DatasetBase):
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.global_shuffle(fleet)
-exe = fluid.Executor(fluid.CPUPlace())
-exe.run(fluid.default_startup_program())
-exe.train_from_dataset(fluid.default_main_program(), dataset)
+exe = paddle.static.Executor(paddle.CPUPlace())
+startup_program = paddle.static.Program()
+main_program = paddle.static.Program()
+exe.run(startup_program)
+exe.train_from_dataset(main_program, dataset)
dataset.release_memory()
"""
@ -781,6 +806,8 @@ class InMemoryDataset(DatasetBase):
def get_memory_data_size(self, fleet=None):
"""
:api_attr: Static Graph
Get memory data size; users can call this function to get the
number of instances in all workers after loading data into memory.
@ -817,6 +844,8 @@ class InMemoryDataset(DatasetBase):
def get_shuffle_data_size(self, fleet=None):
"""
:api_attr: Static Graph
Get shuffle data size; users can call this function to get the
number of instances in all workers after local/global shuffle.
@ -901,6 +930,8 @@ class InMemoryDataset(DatasetBase):
class QueueDataset(DatasetBase):
"""
:api_attr: Static Graph
QueueDataset processes data in streaming fashion.
Examples:
@ -920,6 +951,8 @@ class QueueDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph
Should be called only once in the user's Python script to initialize settings of the dataset instance.
"""
super(QueueDataset, self).init(**kwargs)
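
By contrast with InMemoryDataset, there is no load/shuffle phase: examples are consumed from the file list as a stream. A minimal sketch (same caveats as the InMemoryDataset sketch above; use_var is illustrative):

import paddle

paddle.enable_static()
slots_vars = [paddle.static.data(name="words", shape=[None, 1], dtype="int64", lod_level=1)]

dataset = paddle.distributed.QueueDataset()
dataset.init(use_var=slots_vars)
dataset.set_filelist(["a.txt", "b.txt"])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
exe.train_from_dataset(paddle.static.default_main_program(), dataset)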

@ -16,6 +16,7 @@
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
from ..utils import deprecated
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
@ -335,6 +336,7 @@ class InMemoryDataset(DatasetBase):
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""
@deprecated(since="2.0.0", update_to="paddle.distributed.InMemoryDataset")
def __init__(self):
""" Init. """
super(InMemoryDataset, self).__init__()
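
The @deprecated(since=..., update_to=...) markers added throughout this file come from paddle.utils. As a mental model only — this is a hypothetical stand-in, not Paddle's actual implementation — such a decorator boils down to:

import functools
import warnings

def deprecated(since, update_to):
    # Hypothetical sketch: warn on every call and point at the replacement.
    def wrap(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            warnings.warn(
                "%s is deprecated since %s, use %s instead"
                % (func.__name__, since, update_to),
                DeprecationWarning)
            return func(*args, **kwargs)
        return inner
    return wrap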
@ -350,12 +352,18 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_feed_type")
def set_feed_type(self, data_feed_type):
"""
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
@ -376,16 +384,27 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_channel()
self.dataset.create_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_before_train"
)
def _dynamic_adjust_before_train(self, thread_num):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, False)
self.dataset.dynamic_adjust_readers_num(thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_after_train"
)
def _dynamic_adjust_after_train(self):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
self.dataset.dynamic_adjust_readers_num(self.thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_queue_num")
def set_queue_num(self, queue_num):
"""
Set the Dataset output queue number; training threads get data from these queues
@ -404,6 +423,9 @@ class InMemoryDataset(DatasetBase):
self.is_user_set_queue_num = True
self.queue_num = queue_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_ins_id")
def set_parse_ins_id(self, parse_ins_id):
"""
Set whether the Dataset needs to parse ins_id
@ -421,6 +443,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_ins_id = parse_ins_id
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_content")
def set_parse_content(self, parse_content):
"""
Set whether the Dataset needs to parse content
@ -455,6 +480,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_logkey = parse_logkey
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
def set_merge_by_sid(self, merge_by_sid):
"""
Set whether the Dataset needs to merge by sid. If not, one instance means one Pv.
@ -544,6 +572,10 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.postprocess_instance()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_batch_size"
)
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
@ -561,6 +593,10 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_batch_size = fleet_send_batch_size
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds"
)
def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
"""
Set fleet send sleep time, default is 0
@ -578,6 +614,9 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_lineid")
def set_merge_by_lineid(self, merge_size=2):
"""
Set merge by line id; instances with the same line id will be merged after
@ -598,16 +637,27 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_generate_unique_feasigns"
)
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
self.local_shard_num = shard_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._generate_local_tables_unlock"
)
def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
consume_thread_num, shard_num):
self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.load_into_memory")
def load_into_memory(self):
"""
Load data into memory
@ -624,6 +674,9 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run()
self.dataset.load_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.preload_into_memory")
def preload_into_memory(self, thread_num=None):
"""
Load data into memory in async mode
@ -648,6 +701,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_preload_readers()
self.dataset.preload_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.wait_preload_done")
def wait_preload_done(self):
"""
Wait until preload_into_memory is done.
@ -665,6 +721,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.wait_preload_done()
self.dataset.destroy_preload_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.local_shuffle")
def local_shuffle(self):
"""
Local shuffle
@ -681,6 +740,9 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.local_shuffle()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.global_shuffle")
def global_shuffle(self, fleet=None, thread_num=12):
"""
Global shuffle.
@ -726,6 +788,9 @@ class InMemoryDataset(DatasetBase):
if fleet is not None:
fleet._role_maker.barrier_worker()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.release_memory")
def release_memory(self):
"""
:api_attr: Static Graph
@ -774,6 +839,9 @@ class InMemoryDataset(DatasetBase):
"""
return self.dataset.get_pv_data_size()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_memory_data_size")
def get_memory_data_size(self, fleet=None):
"""
Get memory data size, user can call this function to know the num
@ -810,6 +878,9 @@ class InMemoryDataset(DatasetBase):
return global_data_size[0]
return local_data_size[0]
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_shuffle_data_size")
def get_shuffle_data_size(self, fleet=None):
"""
Get shuffle data size, user can call this function to know the num
@ -869,6 +940,9 @@ class QueueDataset(DatasetBase):
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
@deprecated(
since="2.0.0",
update_to="paddle.distributed.QueueDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc/thread num/filelist before run,

@ -19,7 +19,7 @@ import tarfile
import os
import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet
from paddle.fluid.log_helper import get_logger
logger = get_logger(
@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return res
-class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
+class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')

@ -22,7 +22,7 @@ import random
import warnings
import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
@ -84,7 +84,7 @@ class CtrReader(object):
return reader
-class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
+class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def get_rand(low=0.0, high=1.0):
return random.random()

@ -0,0 +1,38 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
import re
import collections
import time
import paddle.distributed.fleet as fleet


class MyDataset(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            elements = line.strip().split()[0:]
            output = [("show", [int(elements[0])]),
                      ("click", [int(elements[1])]),
                      ("slot1", [int(elements[2])])]
            yield output

        return data_iter


if __name__ == "__main__":
    d = MyDataset()
    d.run_from_stdin()
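
Driven from a shell pipe, run_from_stdin serializes each sample slot by slot as "<num_values> <value...>". The unit test added in this commit pins this down; for the four input lines it writes, the expected transcript is:

$ cat test_queue_dataset_run_a.txt | python my_data_generator.py
1 2 1 1 1 2
1 2 1 6 1 2
1 2 1 5 1 2
1 2 1 7 1 2

e.g. input line "2 1 2" becomes show=[2], click=[1], slot1=[2], serialized as "1 2 1 1 1 2".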

@ -21,13 +21,13 @@ import tarfile
import random
import paddle
-import paddle.fluid.incubate.data_generator as data_generator
+import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
-class DatasetSimnetReader(data_generator.MultiSlotDataGenerator):
+class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
pass

@ -0,0 +1,176 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import paddle.distributed.fleet as fleet
import os
import sys
import platform


class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", [1, 2, 3, 4]), ("label", [0])

        return data_iter


class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", ["1", "2", "3", "4"]), ("label", ["0"])

        return data_iter


class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield "words"

        return data_iter


class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield "words"

        return data_iter


class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield (1, ["1", "2", "3", "4"]), (2, ["0"])

        return data_iter


class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", "1"), ("label", "0")

        return data_iter


class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def data_iter():
            for i in range(40):
                if i == 1:
                    yield None
                yield ("words", []), ("label", [])

        return data_iter


class TestMultiSlotDataGenerator(unittest.TestCase):
    def test_MultiSlotDataGenerator_basic(self):
        my_ms_dg = MyMultiSlotDataGenerator()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGenerator(unittest.TestCase):
    def test_MyMultiSlotStringDataGenerator_basic(self):
        my_ms_dg = MyMultiSlotStringDataGenerator()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()


class TestMultiSlotStringDataGenerator_2(unittest.TestCase):
    def test_MyMultiSlotStringDataGenerator_stdin(self):
        plats = platform.platform()
        if 'Linux' not in plats:
            print("skip pipecommand UT on MacOS/Win")
            return
        with open("test_queue_dataset_run_a.txt", "w") as f:
            data = "2 1 2\n"
            data += "2 6 2\n"
            data += "2 5 2\n"
            data += "2 7 2\n"
            f.write(data)
        tmp = os.popen(
            "cat test_queue_dataset_run_a.txt | python my_data_generator.py"
        ).readlines()
        expected_res = [
            '1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n'
        ]
        self.assertEqual(tmp, expected_res)
        os.remove("./test_queue_dataset_run_a.txt")


class TestMultiSlotDataGenerator_error(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_2(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_2()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_3(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_3()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_4(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_4()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
    def test_MultiSlotDataGenerator_error(self):
        with self.assertRaises(ValueError):
            my_ms_dg = MyMultiSlotDataGenerator_error_5()
            my_ms_dg.set_batch(1)
            my_ms_dg.run_from_memory()


if __name__ == '__main__':
    unittest.main()

@ -105,11 +105,15 @@ class TestDataset(unittest.TestCase):
dataset.load_into_memory()
dataset.local_shuffle()
-exe = fluid.Executor(fluid.CPUPlace())
-exe.run(fluid.default_startup_program())
+paddle.enable_static()
+exe = paddle.static.Executor(paddle.CPUPlace())
+startup_program = paddle.static.Program()
+main_program = paddle.static.Program()
+exe.run(startup_program)
for i in range(2):
try:
-exe.train_from_dataset(fluid.default_main_program(), dataset)
+exe.train_from_dataset(main_program, dataset)
except ImportError as e:
pass
except Exception as e:
@ -181,20 +185,24 @@ class TestDataset(unittest.TestCase):
use_var=slots_vars)
dataset.set_filelist([filename1, filename2])
dataset.load_into_memory()
+paddle.enable_static()
+exe = paddle.static.Executor(paddle.CPUPlace())
+startup_program = paddle.static.Program()
+main_program = paddle.static.Program()
-exe = fluid.Executor(fluid.CPUPlace())
-exe.run(fluid.default_startup_program())
+exe.run(startup_program)
if self.use_data_loader:
data_loader = fluid.io.DataLoader.from_dataset(dataset,
fluid.cpu_places(),
self.drop_last)
for i in range(self.epoch_num):
for data in data_loader():
-exe.run(fluid.default_main_program(), feed=data)
+exe.run(main_program, feed=data)
else:
for i in range(self.epoch_num):
try:
-exe.train_from_dataset(fluid.default_main_program(),
-dataset)
+exe.train_from_dataset(main_program, dataset)
except Exception as e:
self.assertTrue(False)

@ -150,6 +150,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
